diff --git a/README.md b/README.md index a49c7e3f..320e6d3a 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ As well as: First, [download LibTorch](https://pytorch.org/get-started/locally/). For Mac arm64, use: ```sh -curl -L https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.9.0.zip > libtorch.zip +curl -L https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.9.1.zip > libtorch.zip unzip -q libtorch.zip ``` @@ -42,6 +42,8 @@ gem "torch-rb" It can take 5-10 minutes to compile the extension. Windows is not currently supported. +For distributed data parallel helpers, add the optional `torch-ddp` gem alongside this one. + ## Getting Started A good place to start is [Deep Learning with Torch.rb: A 60 Minute Blitz](tutorials/blitz/README.md). @@ -329,6 +331,8 @@ net.load_state_dict(Torch.load("net.pth")) net.eval ``` +`Torch.load` mirrors the Python API and accepts `map_location` and `weights_only` keyword arguments for compatibility with existing PyTorch checkpoints. + When saving a model in Python to load in Ruby, convert parameters to tensors (due to outstanding bugs in LibTorch) ```python diff --git a/bin/torchrun b/bin/torchrun new file mode 100755 index 00000000..d698ade2 --- /dev/null +++ b/bin/torchrun @@ -0,0 +1,6 @@ +#!/usr/bin/env ruby +# frozen_string_literal: true + +require_relative "../lib/torch/torchrun" + +Torch::TorchRun.start(ARGV) diff --git a/examples/benchmark/training.rb b/examples/benchmark/training.rb new file mode 100644 index 00000000..83b85520 --- /dev/null +++ b/examples/benchmark/training.rb @@ -0,0 +1,372 @@ +# Benchmark training throughput for common architectures/datasets. 
+# Usage examples: +# ruby examples/benchmark/training.rb --arch mnist_cnn --batch-size 128 --gpus 1 +# ruby examples/benchmark/training.rb --arch mnist_cnn --batch-size 128 --gpus 2 --steps 50 + +require "bundler/setup" +require "optparse" +require "torch" +require "torchvision" + +DEFAULT_BACKEND = if Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:available?) && Torch::CUDA.available? + "nccl" +else + Torch::Distributed.get_default_backend_for_device(Torch::Accelerator.current_accelerator) || "gloo" +end +SPAWN_BACKEND_ENV = "TORCH_RB_BENCH_BACKEND".freeze +SPAWN_GROUP_ENV = "TORCH_RB_BENCH_GROUP_SIZE".freeze +SPAWN_BATCH_ENV = "TORCH_RB_BENCH_BATCH_SIZE".freeze + +def parse_list(value) + value.split(",").map(&:strip).reject(&:empty?) +end + +def backend_supported?(backend) + return true unless backend == "nccl" + + Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:available?) && Torch::CUDA.available? +end + +def usable_cuda_device_count + return 0 unless Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:available?) && Torch::CUDA.available? + + Torch::CUDA.respond_to?(:device_count) ? Torch::CUDA.device_count : 0 +rescue + 0 +end + +def spawn_worker_process? + ENV[Torch::Distributed::SPAWN_ENV_KEY] == "1" +end + +def apply_spawn_overrides!(options) + return unless ENV[Torch::Distributed::SPAWN_ENV_KEY] == "1" + + if ENV[SPAWN_BACKEND_ENV] + options[:backends] = [ENV[SPAWN_BACKEND_ENV]] + end + + if ENV[SPAWN_GROUP_ENV] + group_size = ENV[SPAWN_GROUP_ENV].to_i + if group_size.positive? + options[:group_sizes] = [group_size] + options[:gpus] = group_size + end + end + + if ENV[SPAWN_BATCH_ENV] + batch_size = ENV[SPAWN_BATCH_ENV].to_i + options[:batch_sizes] = [batch_size] if batch_size.positive? 
+ end +end + +def with_spawn_env(backend:, group_size:, batch_size:) + previous = { + SPAWN_BACKEND_ENV => ENV[SPAWN_BACKEND_ENV], + SPAWN_GROUP_ENV => ENV[SPAWN_GROUP_ENV], + SPAWN_BATCH_ENV => ENV[SPAWN_BATCH_ENV] + } + + ENV[SPAWN_BACKEND_ENV] = backend + ENV[SPAWN_GROUP_ENV] = group_size.to_s + ENV[SPAWN_BATCH_ENV] = batch_size.to_s + + yield +ensure + ENV[SPAWN_BACKEND_ENV] = previous[SPAWN_BACKEND_ENV] + ENV[SPAWN_GROUP_ENV] = previous[SPAWN_GROUP_ENV] + ENV[SPAWN_BATCH_ENV] = previous[SPAWN_BATCH_ENV] +end + +class MnistCnn < Torch::NN::Module + def initialize + super() + @conv1 = Torch::NN::Conv2d.new(1, 32, 3, stride: 1) + @conv2 = Torch::NN::Conv2d.new(32, 64, 3, stride: 1) + @dropout1 = Torch::NN::Dropout2d.new(p: 0.25) + @dropout2 = Torch::NN::Dropout2d.new(p: 0.5) + @fc1 = Torch::NN::Linear.new(9216, 128) + @fc2 = Torch::NN::Linear.new(128, 10) + end + + def forward(x) + x = Torch::NN::F.relu(@conv1.call(x)) + x = Torch::NN::F.relu(@conv2.call(x)) + x = Torch::NN::F.max_pool2d(x, 2) + x = @dropout1.call(x) + x = Torch.flatten(x, start_dim: 1) + x = Torch::NN::F.relu(@fc1.call(x)) + x = @dropout2.call(x) + Torch::NN::F.log_softmax(@fc2.call(x), 1) + end +end + +ARCH_CONFIGS = { + "mnist_cnn" => { + model: -> { MnistCnn.new }, + dataset: :mnist + } +}.freeze + +def parse_options + defaults = { + arch: "mnist_cnn", + batch_sizes: [128], + steps: 100, + warmup: 10, + backends: [DEFAULT_BACKEND], + gpus: Torch::CUDA.available? ? 
[Torch::CUDA.device_count, 1].max : 1, + group_sizes: nil, + data_dir: File.join(__dir__, "data"), + lr: 0.01 + } + + OptionParser.new do |opts| + opts.banner = "Usage: ruby examples/benchmark/training.rb [options]" + opts.on("--arch NAME", "Architecture to benchmark (#{ARCH_CONFIGS.keys.join(', ')}, default: #{defaults[:arch]})") { |v| defaults[:arch] = v } + opts.on("--batch-size N", Integer, "Batch size per process (default: #{defaults[:batch_sizes].first})") { |v| defaults[:batch_sizes] = [v] } + opts.on("--batch-sizes LIST", String, "Comma-separated batch sizes per process") { |v| defaults[:batch_sizes] = parse_list(v).map(&:to_i) } + opts.on("--steps N", Integer, "Number of timed training steps (default: #{defaults[:steps]})") { |v| defaults[:steps] = v } + opts.on("--warmup N", Integer, "Number of warmup steps not included in timing (default: #{defaults[:warmup]})") { |v| defaults[:warmup] = v } + opts.on("--backend NAME", String, "Process group backend (default: #{defaults[:backends].first})") { |v| defaults[:backends] = [v] } + opts.on("--backends LIST", String, "Comma-separated list of backends to benchmark (gloo,nccl)") { |v| defaults[:backends] = parse_list(v) } + opts.on("--gpus N", Integer, "Number of GPUs/processes to use (1 for non-distributed)") { |v| defaults[:gpus] = v } + opts.on("--group-sizes LIST", String, "Process group sizes to benchmark (default: 1..gpus)") { |v| defaults[:group_sizes] = parse_list(v).map(&:to_i) } + opts.on("--data-dir PATH", String, "Directory for cached datasets (default: #{defaults[:data_dir]})") { |v| defaults[:data_dir] = v } + opts.on("--lr FLOAT", Float, "Learning rate (default: #{defaults[:lr]})") { |v| defaults[:lr] = v } + end.parse!(ARGV) + + defaults[:group_sizes] ||= (1..defaults[:gpus]).to_a + defaults +end + +def dataset_for(name, data_dir, distributed:, rank:, world_size:) + case name + when :mnist + transforms = TorchVision::Transforms::Compose.new([ + TorchVision::Transforms::ToTensor.new, + 
TorchVision::Transforms::Normalize.new([0.1307], [0.3081]) + ]) + + if distributed + if rank.zero? + train = TorchVision::Datasets::MNIST.new(data_dir, train: true, download: true, transform: transforms) + Torch::Distributed.barrier + else + Torch::Distributed.barrier + train = TorchVision::Datasets::MNIST.new(data_dir, train: true, download: false, transform: transforms) + end + indices = rank.step(train.size - 1, world_size).to_a + Torch::Utils::Data::Subset.new(train, indices) + else + TorchVision::Datasets::MNIST.new(data_dir, train: true, download: true, transform: transforms) + end + else + raise ArgumentError, "Unknown dataset: #{name}" + end +end + +def sync_cuda_if_needed(device) + return unless device && device.type == "cuda" + return unless Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:synchronize) + + Torch::CUDA.synchronize +end + +def benchmark_worker(rank, world_size, port, options) + arch = options.fetch(:arch) + config = ARCH_CONFIGS[arch] + raise ArgumentError, "Unsupported architecture #{arch.inspect}" unless config + + distributed = world_size > 1 + accelerator = Torch::Accelerator.current_accelerator + selected_backend = options[:backend] || Torch::Distributed.get_default_backend_for_device(accelerator) || DEFAULT_BACKEND + if distributed + store = Torch::Distributed::TCPStore.new("127.0.0.1", port, world_size, rank.zero?) + Torch::Distributed.init_process_group(selected_backend, store: store, rank: rank, world_size: world_size) + end + + cuda_devices = usable_cuda_device_count + device = if cuda_devices.positive? && options[:gpus] > 0 + Torch.device("cuda:#{rank % cuda_devices}") + else + Torch.device("cpu") + end + + model = config[:model].call.to(device) + if distributed + ddp_devices = device.type == "cuda" ? 
[device.index] : nil + model = Torch::NN::Parallel::DistributedDataParallel.new(model, device_ids: ddp_devices) + end + optimizer = Torch::Optim::SGD.new(model.parameters, lr: options[:lr]) + + loader = Torch::Utils::Data::DataLoader.new( + dataset_for(config[:dataset], options[:data_dir], distributed: distributed, rank: rank, world_size: world_size), + batch_size: options[:batch_size], + shuffle: true + ) + + warmup_steps = options[:warmup] + timed_steps = options[:steps] + total_steps = warmup_steps + timed_steps + losses = [] + + # Warm up the model (including one full timed-length pass) to avoid init overhead in measurements. + step_idx = 0 + loader.each do |data, target| + data = data.to(device) + target = target.to(device) + + optimizer.zero_grad + loss = Torch::NN::F.nll_loss(model.call(data), target) + loss.backward + optimizer.step + + step_idx += 1 + break if step_idx >= total_steps + end + + sync_cuda_if_needed(device) + Torch::Distributed.barrier if distributed + + timed = 0 + step_idx = 0 + start = Process.clock_gettime(Process::CLOCK_MONOTONIC) + loader.each do |data, target| + data = data.to(device) + target = target.to(device) + + optimizer.zero_grad + loss = Torch::NN::F.nll_loss(model.call(data), target) + loss.backward + optimizer.step + + loss_value = loss.item + if distributed + loss_tensor = Torch.tensor([loss_value], device: device) + Torch::Distributed.all_reduce(loss_tensor) + loss_value = loss_tensor.item / world_size.to_f + end + losses << loss_value if !distributed || rank.zero? + + step_idx += 1 + break if step_idx >= timed_steps + end + + sync_cuda_if_needed(device) + Torch::Distributed.barrier if distributed + elapsed = Process.clock_gettime(Process::CLOCK_MONOTONIC) - start + timed = step_idx + + images = timed * options[:batch_size] * world_size + throughput = elapsed.positive? ? 
images.to_f / elapsed : 0.0 + initial_loss = losses.first || 0.0 + final_loss = losses.last || initial_loss + loss_delta = initial_loss - final_loss + loss_delta_per_step = timed.zero? ? 0.0 : loss_delta / timed + loss_delta_per_sec = elapsed.zero? ? 0.0 : loss_delta / elapsed + + result = if !distributed || rank.zero? + { + backend: selected_backend, + world_size: world_size, + batch_size: options[:batch_size], + arch: arch, + dataset: config[:dataset], + elapsed: elapsed, + timed_steps: timed, + images: images, + throughput: throughput, + initial_loss: initial_loss, + final_loss: final_loss, + loss_delta: loss_delta, + loss_delta_per_step: loss_delta_per_step, + loss_delta_per_sec: loss_delta_per_sec + } + end + + Torch::Distributed.destroy_process_group if distributed + result +end + +def run_benchmark_case(world_size, options) + if world_size > 1 + outputs = Torch::Distributed.fork_world(world_size, start_method: :spawn) do |rank, port| + benchmark_worker(rank, world_size, port, options) + end + outputs.compact.first + else + benchmark_worker(0, 1, Torch::Distributed.free_port, options) + end +end + +def print_summary_table(results) + puts "\nBenchmark comparison (processing vs convergence)" + puts "Processing speed: images per second. 
Convergence speed: average loss reduction per step and per second.\n" + + headers = ["Backend", "Proc Group", "Batch", "Images/s", "Loss delta/step", "Loss delta/s", "Final loss"] + formatters = [ + ->(r) { r[:backend] }, + ->(r) { r[:world_size] }, + ->(r) { r[:batch_size] }, + ->(r) { format("%.1f", r[:throughput]) }, + ->(r) { format("%.4f", r[:loss_delta_per_step]) }, + ->(r) { format("%.4f", r[:loss_delta_per_sec]) }, + ->(r) { format("%.4f", r[:final_loss]) } + ] + + widths = headers.each_with_index.map do |header, idx| + [header.length, results.map { |r| formatters[idx].call(r).to_s.length }.max].compact.max + end + + header_line = headers.each_with_index.map { |h, idx| h.ljust(widths[idx]) }.join(" | ") + divider = widths.map { |w| "-" * w }.join("-+-") + puts header_line + puts divider + + results.sort_by { |r| [r[:backend], r[:world_size], r[:batch_size]] }.each do |result| + row = formatters.each_with_index.map { |formatter, idx| formatter.call(result).to_s.ljust(widths[idx]) } + puts row.join(" | ") + end +end + +options = parse_options +apply_spawn_overrides!(options) +max_world_size = options[:gpus] +raise "Number of GPUs requested must be >= 1" if max_world_size < 1 +Torch.manual_seed(1) + +group_sizes = options[:group_sizes].map { |v| [v, max_world_size].min }.select { |v| v >= 1 }.uniq.sort +batch_sizes = options[:batch_sizes].map { |v| [v, 1].max }.uniq +backends = options[:backends].map(&:downcase).uniq + +if group_sizes.any? { |size| size > 1 } + raise "torch.distributed is not available" unless Torch::Distributed.available? +end + +results = [] + +backends.each do |backend| + unless backend_supported?(backend) + warn "Skipping backend=#{backend} because required accelerator support is unavailable." 
+ next + end + + group_sizes.each do |world_size| + batch_sizes.each do |batch_size| + run_options = options.merge(batch_size: batch_size, backend: backend, gpus: world_size) + puts "Running backend=#{backend}, group_size=#{world_size}, batch_size=#{batch_size}..." unless spawn_worker_process? + with_spawn_env(backend: backend, group_size: world_size, batch_size: batch_size) do + results << run_benchmark_case(world_size, run_options) + end + end + end +end + +results.compact! + +if results.empty? + puts "No benchmark results to report." +else + print_summary_table(results) +end diff --git a/examples/mnist/distributed.rb b/examples/mnist/distributed.rb new file mode 100644 index 00000000..d300ead0 --- /dev/null +++ b/examples/mnist/distributed.rb @@ -0,0 +1,238 @@ +# Distributed MNIST training with Torch::Distributed + DistributedDataParallel +# Run with: ruby examples/mnist/distributed.rb --gpus 2 + +require "bundler/setup" +require "optparse" +require "torch" +require "torchvision" +require "tmpdir" + +unless Torch::Distributed.available? + abort "torch.distributed was not built in this binary" +end + +DEFAULT_CHECKPOINT_PATH = File.join(Dir.tmpdir, "mnist_ddp_checkpoint.pt") +DEFAULT_BACKEND = if Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:available?) && Torch::CUDA.available? 
+ "nccl" +else + Torch::Distributed.get_default_backend_for_device(Torch::Accelerator.current_accelerator) || "gloo" +end + +class MyNet < Torch::NN::Module + def initialize + super() + @conv1 = Torch::NN::Conv2d.new(1, 32, 3, stride: 1) + @conv2 = Torch::NN::Conv2d.new(32, 64, 3, stride: 1) + @dropout1 = Torch::NN::Dropout2d.new(p: 0.25) + @dropout2 = Torch::NN::Dropout2d.new(p: 0.5) + @fc1 = Torch::NN::Linear.new(9216, 128) + @fc2 = Torch::NN::Linear.new(128, 10) + end + + def forward(x) + x = Torch::NN::F.relu(@conv1.call(x)) + x = Torch::NN::F.relu(@conv2.call(x)) + x = Torch::NN::F.max_pool2d(x, 2) + x = @dropout1.call(x) + x = Torch.flatten(x, start_dim: 1) + x = Torch::NN::F.relu(@fc1.call(x)) + x = @dropout2.call(x) + Torch::NN::F.log_softmax(@fc2.call(x), 1) + end +end + +def parse_options + defaults = { + epochs: 5, + batch_size: 64, + lr: 1.0, + gamma: 0.7, + backend: DEFAULT_BACKEND, + gpus: Torch::CUDA.available? ? [Torch::CUDA.device_count, 1].max : 1, + log_interval: 20, + data_dir: File.join(__dir__, "data"), + checkpoint_path: DEFAULT_CHECKPOINT_PATH, + resume: false + } + + OptionParser.new do |opts| + opts.banner = "Usage: ruby distributed.rb [options]" + opts.on("--epochs N", Integer, "Number of epochs (default: #{defaults[:epochs]})") { |v| defaults[:epochs] = v } + opts.on("--batch-size N", Integer, "Batch size per process (default: #{defaults[:batch_size]})") { |v| defaults[:batch_size] = v } + opts.on("--lr FLOAT", Float, "Learning rate (default: #{defaults[:lr]})") { |v| defaults[:lr] = v } + opts.on("--gamma FLOAT", Float, "LR scheduler gamma (default: #{defaults[:gamma]})") { |v| defaults[:gamma] = v } + opts.on("--backend NAME", String, "Process group backend (default: #{defaults[:backend]})") { |v| defaults[:backend] = v } + opts.on("--gpus N", Integer, "Number of GPUs/processes to use") { |v| defaults[:gpus] = v } + opts.on("--log-interval N", Integer, "Batches between log statements") { |v| defaults[:log_interval] = v } + 
opts.on("--data-dir PATH", String, "Directory for cached MNIST data") { |v| defaults[:data_dir] = v } + opts.on("--checkpoint PATH", String, "Checkpoint file to save to (default: #{defaults[:checkpoint_path]})") { |v| defaults[:checkpoint_path] = v } + opts.on("--resume", "Load checkpoint weights before training if the file exists") { defaults[:resume] = true } + end.parse!(ARGV) + + defaults +end + +def load_datasets(rank, data_dir) + transforms = TorchVision::Transforms::Compose.new([ + TorchVision::Transforms::ToTensor.new, + TorchVision::Transforms::Normalize.new([0.1307], [0.3081]) + ]) + + if rank.zero? + train = TorchVision::Datasets::MNIST.new(data_dir, train: true, download: true, transform: transforms) + test = TorchVision::Datasets::MNIST.new(data_dir, train: false, download: true, transform: transforms) + Torch::Distributed.barrier + else + Torch::Distributed.barrier + train = TorchVision::Datasets::MNIST.new(data_dir, train: true, download: false, transform: transforms) + test = TorchVision::Datasets::MNIST.new(data_dir, train: false, download: false, transform: transforms) + end + + [train, test] +end + +def subset_for_rank(dataset, rank, world_size) + indices = rank.step(dataset.size - 1, world_size).to_a + Torch::Utils::Data::Subset.new(dataset, indices) +end + +def checkpoint_map_location(device, rank) + accelerator_device = Torch::Accelerator.current_accelerator + return nil unless accelerator_device + + accelerator_type = accelerator_device.type + target_index = device.index + if target_index.nil? && Torch::Accelerator.respond_to?(:device_count) + count = Torch::Accelerator.device_count + target_index = count.positive? ? 
rank % count : 0 + end + { "#{accelerator_type}:0" => "#{accelerator_type}:#{target_index}" } +end + +def load_checkpoint_if_present(ddp, device, rank, path) + return false unless path && File.exist?(path) + + Torch::Distributed.barrier + kwargs = { weights_only: true } + map_location = checkpoint_map_location(device, rank) + kwargs[:map_location] = map_location if map_location + state_dict = Torch.load(path, **kwargs) + ddp.module.load_state_dict(state_dict) + true +end + +def save_checkpoint(ddp, path, rank) + return unless path + + Torch.save(ddp.module.state_dict, path) if rank.zero? + Torch::Distributed.barrier + puts "Saved checkpoint to #{path}" if rank.zero? +end + +def train_epoch(model, device, loader, optimizer, epoch, rank, log_interval) + model.train + loader.each_with_index do |(data, target), batch_idx| + data = data.to(device) + target = target.to(device) + + optimizer.zero_grad + loss = Torch::NN::F.nll_loss(model.call(data), target) + loss.backward + optimizer.step + + next unless rank.zero? && (batch_idx % log_interval).zero? 
+ + processed = batch_idx * data.size(0) + total = loader.dataset.size + percent = 100.0 * processed / total + puts "Rank #{rank} | Epoch #{epoch} [#{processed}/#{total} (#{percent.round})%] Loss: #{'%.4f' % loss.item}" + end +end + +def evaluate(model, device, loader) + model.eval + loss = 0.0 + correct = 0 + Torch.no_grad do + loader.each do |data, target| + data = data.to(device) + target = target.to(device) + output = model.call(data) + loss += Torch::NN::F.nll_loss(output, target, reduction: "sum").item + pred = output.argmax(1, keepdim: true) + correct += pred.eq(target.view_as(pred)).sum.item + end + end + + loss /= loader.dataset.size + acc = 100.0 * correct / loader.dataset.size + puts "Test set: Average loss: #{format('%.4f', loss)}, Accuracy: #{correct}/#{loader.dataset.size} (#{format('%.1f', acc)}%)" +end + +def run_worker(rank, world_size, port, options) + store = Torch::Distributed::TCPStore.new("127.0.0.1", port, world_size, rank.zero?) + accelerator = Torch::Accelerator.current_accelerator + backend = options[:backend] || Torch::Distributed.get_default_backend_for_device(accelerator) || DEFAULT_BACKEND + Torch::Distributed.init_process_group(backend, store: store, rank: rank, world_size: world_size) + + device = if Torch::CUDA.available? && options[:gpus] > 0 + Torch.device("cuda:#{rank % Torch::CUDA.device_count}") + else + Torch.device("cpu") + end + + model = MyNet.new.to(device) + ddp = Torch::NN::Parallel::DistributedDataParallel.new(model, device_ids: device.type == "cuda" ? 
[device.index] : nil) + optimizer = Torch::Optim::Adadelta.new(ddp.module.parameters, lr: options[:lr]) + scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: 1, gamma: options[:gamma]) + + train_dataset, test_dataset = load_datasets(rank, options[:data_dir]) + train_subset = subset_for_rank(train_dataset, rank, world_size) + train_loader = Torch::Utils::Data::DataLoader.new(train_subset, batch_size: options[:batch_size], shuffle: true) + test_loader = Torch::Utils::Data::DataLoader.new(test_dataset, batch_size: options[:batch_size], shuffle: false) if rank.zero? + checkpoint_path = options[:checkpoint_path] + + if options[:resume] + loaded = load_checkpoint_if_present(ddp, device, rank, checkpoint_path) + if rank.zero? + if loaded + puts "Loaded checkpoint weights from #{checkpoint_path}" + else + puts "No checkpoint found at #{checkpoint_path}, starting from random initialization" + end + end + end + + options[:epochs].times do |epoch_idx| + epoch = epoch_idx + 1 + train_epoch(ddp, device, train_loader, optimizer, epoch, rank, options[:log_interval]) + if rank.zero? + evaluate(ddp.module, device, test_loader) + end + save_checkpoint(ddp, checkpoint_path, rank) if checkpoint_path + end + + Torch::Distributed.destroy_process_group +end + +options = parse_options +world_size = options[:gpus] +raise "Number of GPUs requested must be >= 1" if world_size < 1 +if Torch::CUDA.available? 
+ max_devices = Torch::CUDA.device_count + if world_size > max_devices + raise "Requested #{world_size} GPUs but only #{max_devices} visible" + end +else + puts "CUDA not available, running #{world_size} CPU workers" +end + +Torch.manual_seed(1) + +if world_size == 1 + run_worker(0, 1, Torch::Distributed.free_port, options) +else + Torch::Distributed.fork_world(world_size, start_method: :spawn) do |rank, port| + run_worker(rank, world_size, port, options) + end +end diff --git a/ext/torch/accelerator.cpp b/ext/torch/accelerator.cpp new file mode 100644 index 00000000..45cfcb41 --- /dev/null +++ b/ext/torch/accelerator.cpp @@ -0,0 +1,52 @@ +#include +#include +#include + +#include + +#include "utils.h" + +namespace { + +inline bool accelerator_available(c10::DeviceType device_type) { + return at::globalContext() + .getAcceleratorHooksInterface(device_type) + .isAvailable(); +} + +} // namespace + +void init_accelerator(Rice::Module& m) { + auto rb_mAccelerator = Rice::define_module_under(m, "Accelerator"); + + rb_mAccelerator.define_singleton_function( + "_current_device", + []() -> VALUE { + auto acc = at::getAccelerator(false); + if (!acc.has_value()) { + return Rice::Nil; + } + torch::Device device(acc.value()); + return Rice::detail::To_Ruby().convert(device); + }); + + rb_mAccelerator.define_singleton_function( + "_is_available", + []() { + auto acc = at::getAccelerator(false); + if (!acc.has_value()) { + return false; + } + return accelerator_available(acc.value()); + }); + + rb_mAccelerator.define_singleton_function( + "_device_count", + []() { + auto acc = at::getAccelerator(false); + if (!acc.has_value()) { + return 0; + } + return static_cast(at::accelerator::deviceCount()); + }); +} diff --git a/ext/torch/cuda.cpp b/ext/torch/cuda.cpp index 23f38d80..5a7c52e4 100644 --- a/ext/torch/cuda.cpp +++ b/ext/torch/cuda.cpp @@ -1,13 +1,36 @@ #include +#ifdef HAVE_C10_CUDA +#include +#endif #include #include "utils.h" void init_cuda(Rice::Module& m) { - 
Rice::define_module_under(m, "CUDA") + auto rb_mCUDA = Rice::define_module_under(m, "CUDA"); + + rb_mCUDA .define_singleton_function("available?", &torch::cuda::is_available) .define_singleton_function("device_count", &torch::cuda::device_count) .define_singleton_function("manual_seed", &torch::cuda::manual_seed) .define_singleton_function("manual_seed_all", &torch::cuda::manual_seed_all); + +#ifdef HAVE_C10_CUDA + rb_mCUDA.define_singleton_function( + "set_device", + [](int device_id) { + c10::cuda::set_device(device_id); + return Rice::Nil; + }); +#else + rb_mCUDA.define_singleton_function( + "set_device", + [](int) { + rb_raise( + rb_eRuntimeError, + "c10 CUDA support is not available in this build; set_device cannot be used"); + return Rice::Nil; + }); +#endif } diff --git a/ext/torch/distributed.cpp b/ext/torch/distributed.cpp new file mode 100644 index 00000000..a3d50680 --- /dev/null +++ b/ext/torch/distributed.cpp @@ -0,0 +1,344 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#if defined(USE_C10D) && defined(USE_C10D_NCCL) +#include +#include +#endif + +#include +#include + +#include "utils.h" + +#ifdef USE_C10D +#include +#include +#include +#include +#include +#include +#include +#endif + +#if defined(USE_C10D) && defined(USE_C10D_NCCL) +#include +#endif + +#if defined(USE_C10D) && !defined(_WIN32) +#include +#endif + +namespace { + +#ifdef USE_C10D + +using StorePtr = c10::intrusive_ptr<::c10d::Store>; +using ProcessGroupPtr = c10::intrusive_ptr<::c10d::Backend>; + +struct StoreWrapper { + StoreWrapper() = default; + explicit StoreWrapper(StorePtr store) : store_(std::move(store)) {} + + StorePtr store_; +}; + +struct ProcessGroupWrapper { + ProcessGroupWrapper() = default; + explicit ProcessGroupWrapper(ProcessGroupPtr pg) : pg_(std::move(pg)) {} + + ProcessGroupPtr pg_; +}; + +ProcessGroupPtr default_process_group; +std::once_flag default_pg_cleanup_once; + +void shutdown_default_process_group() { + if 
(default_process_group) { + try { + default_process_group->shutdown(); + } catch (...) { + // best effort; ensure reset still happens + } + default_process_group.reset(); + } +} + +void register_default_pg_cleanup() { + std::call_once(default_pg_cleanup_once, []() { + std::atexit([]() { shutdown_default_process_group(); }); + }); +} + +ProcessGroupPtr resolve_process_group(Rice::Object pg_obj) { + if (pg_obj.is_nil()) { + if (!default_process_group) { + rb_raise(rb_eRuntimeError, "Distributed process group not initialized"); + } + return default_process_group; + } + auto& wrapper = Rice::detail::From_Ruby().convert(pg_obj.value()); + if (!wrapper.pg_) { + rb_raise(rb_eRuntimeError, "Invalid process group"); + } + return wrapper.pg_; +} + +int reduce_op_from_int(int code) { + if (code < 0 || code > static_cast(::c10d::ReduceOp::UNUSED)) { + rb_raise(rb_eArgError, "Unknown reduce op code"); + } + return code; +} + +#endif + +} // namespace + +void init_distributed(Rice::Module& m) { + auto rb_mDistributed = Rice::define_module_under(m, "Distributed"); +#ifdef USE_C10D + register_default_pg_cleanup(); + rb_mDistributed.define_singleton_function("available?", []() { return true; }); + + auto rb_cStore = Rice::define_class_under(rb_mDistributed, "Store"); + rb_cStore.define_method( + "_native?", + [](StoreWrapper& self) { + return static_cast(self.store_); + }); + + auto rb_cProcessGroup = Rice::define_class_under(rb_mDistributed, "ProcessGroup") + .define_method( + "rank", + [](ProcessGroupWrapper& self) { + return self.pg_ ? self.pg_->getRank() : -1; + }) + .define_method( + "size", + [](ProcessGroupWrapper& self) { + return self.pg_ ? 
self.pg_->getSize() : 0; + }) + .define_method( + "backend", + [](ProcessGroupWrapper& self) { + if (!self.pg_) { + return std::string(); + } + return self.pg_->getBackendName(); + }); + + rb_mDistributed.define_singleton_function( + "_create_tcp_store", + [rb_cStore](const std::string& host, + int port, + int world_size, + bool is_master, + int64_t timeout_millis, + bool wait_for_workers) -> Rice::Object { + ::c10d::TCPStoreOptions opts; + opts.port = static_cast(port); + opts.isServer = is_master; + opts.numWorkers = world_size; + opts.waitWorkers = wait_for_workers; + opts.timeout = std::chrono::milliseconds(timeout_millis); + auto store = c10::make_intrusive<::c10d::TCPStore>(host, opts); + // Pass ownership first, then the Ruby class so Rice doesn't treat the class as the owner flag + return Rice::Data_Object(new StoreWrapper(store), true, rb_cStore); + }); + + rb_mDistributed.define_singleton_function( + "_create_file_store", + [rb_cStore](const std::string& path, int world_size) -> Rice::Object { + auto store = c10::make_intrusive<::c10d::FileStore>(path, world_size); + return Rice::Data_Object(new StoreWrapper(store), true, rb_cStore); + }); + +#if !defined(_WIN32) + rb_mDistributed.define_singleton_function( + "_create_hash_store", + [rb_cStore]() -> Rice::Object { + auto store = c10::make_intrusive<::c10d::HashStore>(); + return Rice::Data_Object(new StoreWrapper(store), true, rb_cStore); + }); +#endif + + rb_mDistributed.define_singleton_function( + "_init_process_group", + [rb_cProcessGroup](const std::string& backend, + StoreWrapper& store_wrapper, + int rank, + int world_size, + int64_t timeout_millis, + int device_id) -> Rice::Object { + StorePtr store = store_wrapper.store_; + if (!store) { + rb_raise(rb_eArgError, "Store is required for init_process_group"); + } + + std::string backend_lower = backend; + std::transform(backend_lower.begin(), backend_lower.end(), backend_lower.begin(), ::tolower); + + ProcessGroupPtr pg; + if (backend_lower == 
"gloo") { +#ifdef USE_C10D_GLOO + auto options = ::c10d::ProcessGroupGloo::Options::create(); + options->timeout = std::chrono::milliseconds(timeout_millis); + options->devices.push_back(::c10d::ProcessGroupGloo::createDefaultDevice()); + pg = c10::make_intrusive<::c10d::ProcessGroupGloo>(store, rank, world_size, options); +#else + rb_raise(rb_eRuntimeError, "Gloo backend is not available in this build"); +#endif + } else if (backend_lower == "nccl") { +#if defined(USE_C10D_NCCL) + auto options = c10::make_intrusive<::c10d::ProcessGroupNCCL::Options>(); + options->timeout = std::chrono::milliseconds(timeout_millis); + pg = c10::make_intrusive<::c10d::ProcessGroupNCCL>(store, rank, world_size, options); +#else + rb_raise(rb_eRuntimeError, "NCCL backend is not available in this build"); +#endif + } else { + rb_raise(rb_eArgError, "Unsupported backend: %s", backend.c_str()); + } + + if (device_id >= 0 && backend_lower == "nccl") { +#if defined(USE_C10D_NCCL) + if (!torch::cuda::is_available()) { + rb_raise(rb_eRuntimeError, "CUDA is not available for NCCL backend"); + } + auto device_count = torch::cuda::device_count(); + if (device_id >= static_cast(device_count)) { + rb_raise( + rb_eArgError, + "Invalid device_id %d for NCCL backend (available devices: %d)", + device_id, + static_cast(device_count)); + } + c10::cuda::set_device(device_id); + pg->setBoundDeviceId(c10::Device(c10::kCUDA, device_id)); +#endif + } + + default_process_group = pg; + return Rice::Data_Object(new ProcessGroupWrapper(pg), true, rb_cProcessGroup); + }); + + rb_mDistributed.define_singleton_function( + "_destroy_process_group", + []() { + shutdown_default_process_group(); + return Rice::Nil; + }); + + rb_mDistributed.define_singleton_function( + "_initialized?", + []() { + return static_cast(default_process_group); + }); + + rb_mDistributed.define_singleton_function( + "_default_process_group", + [rb_cProcessGroup]() -> Rice::Object { + if (!default_process_group) { + return Rice::Nil; + } + 
return Rice::Data_Object(new ProcessGroupWrapper(default_process_group), true, rb_cProcessGroup); + }); + + rb_mDistributed.define_singleton_function( + "_get_world_size", + [](Rice::Object pg_obj) { + auto pg = resolve_process_group(pg_obj); + return pg->getSize(); + }); + + rb_mDistributed.define_singleton_function( + "_get_rank", + [](Rice::Object pg_obj) { + auto pg = resolve_process_group(pg_obj); + return pg->getRank(); + }); + + rb_mDistributed.define_singleton_function( + "_barrier", + [](Rice::Object pg_obj) { + auto pg = resolve_process_group(pg_obj); + ::c10d::BarrierOptions opts; + auto work = pg->barrier(opts); + work->wait(); + return Rice::Nil; + }); + + rb_mDistributed.define_singleton_function( + "_all_reduce", + [](torch::Tensor& tensor, int op_code, Rice::Object pg_obj) { + auto pg = resolve_process_group(pg_obj); + ::c10d::AllreduceOptions opts; + opts.reduceOp = ::c10d::ReduceOp(static_cast<::c10d::ReduceOp::RedOpType>(reduce_op_from_int(op_code))); + std::vector tensors{tensor}; + auto work = pg->allreduce(tensors, opts); + work->wait(); + return tensor; + }); + + rb_mDistributed.define_singleton_function( + "_broadcast", + [](torch::Tensor& tensor, int src, Rice::Object pg_obj) { + auto pg = resolve_process_group(pg_obj); + ::c10d::BroadcastOptions opts; + opts.rootRank = src; + std::vector tensors{tensor}; + auto work = pg->broadcast(tensors, opts); + work->wait(); + return tensor; + }); + + rb_mDistributed.define_singleton_function( + "_register_ddp_hook", + [](torch::Tensor& tensor, ProcessGroupWrapper& pg_wrapper, int world_size) -> unsigned { + if (!pg_wrapper.pg_) { + rb_raise(rb_eArgError, "Process group is required for DDP hook registration"); + } + if (world_size <= 0) { + rb_raise(rb_eArgError, "world_size must be positive"); + } + + auto pg = pg_wrapper.pg_; + // Register a native autograd hook that all-reduces gradients and scales + // them by the world size. This avoids calling back into Ruby from + // autograd worker threads. 
+ unsigned handle = tensor.register_hook([pg, world_size](const at::Tensor& grad) { + ::c10d::AllreduceOptions opts; + opts.reduceOp = ::c10d::ReduceOp::SUM; + std::vector tensors{grad}; + auto work = pg->allreduce(tensors, opts); + work->wait(); + grad.div_(static_cast(world_size)); + return grad; + }); + + return handle; + }); + + auto rb_mReduceOp = Rice::define_module_under(rb_mDistributed, "ReduceOp"); + rb_mReduceOp.const_set("SUM", INT2NUM(static_cast(::c10d::ReduceOp::SUM))); + rb_mReduceOp.const_set("AVG", INT2NUM(static_cast(::c10d::ReduceOp::AVG))); + rb_mReduceOp.const_set("PRODUCT", INT2NUM(static_cast(::c10d::ReduceOp::PRODUCT))); + rb_mReduceOp.const_set("MIN", INT2NUM(static_cast(::c10d::ReduceOp::MIN))); + rb_mReduceOp.const_set("MAX", INT2NUM(static_cast(::c10d::ReduceOp::MAX))); + rb_mReduceOp.const_set("BAND", INT2NUM(static_cast(::c10d::ReduceOp::BAND))); + rb_mReduceOp.const_set("BOR", INT2NUM(static_cast(::c10d::ReduceOp::BOR))); + rb_mReduceOp.const_set("BXOR", INT2NUM(static_cast(::c10d::ReduceOp::BXOR))); + rb_mReduceOp.const_set("PREMUL_SUM", INT2NUM(static_cast(::c10d::ReduceOp::PREMUL_SUM))); + + rb_mDistributed.const_set("DEFAULT_TIMEOUT", INT2NUM(::kProcessGroupDefaultTimeout.count() / 1000)); +#else + rb_mDistributed.define_singleton_function("available?", []() { return false; }); +#endif +} diff --git a/ext/torch/ext.cpp b/ext/torch/ext.cpp index eb6fb7d3..c07528b8 100644 --- a/ext/torch/ext.cpp +++ b/ext/torch/ext.cpp @@ -6,6 +6,7 @@ void init_fft(Rice::Module& m); void init_linalg(Rice::Module& m); void init_nn(Rice::Module& m); void init_special(Rice::Module& m); +void init_accelerator(Rice::Module& m); void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions); void init_torch(Rice::Module& m); @@ -40,6 +41,7 @@ void Init_ext() { init_fft(m); init_linalg(m); init_special(m); + init_accelerator(m); init_backends(m); init_cuda(m); diff --git a/ext/torch/extconf.rb b/ext/torch/extconf.rb index 
cf3c6706..0032088d 100644 --- a/ext/torch/extconf.rb +++ b/ext/torch/extconf.rb @@ -47,6 +47,7 @@ with_cuda = false if Dir["#{lib}/*torch_cuda*"].any? $LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib) + $INCFLAGS += " -I#{cuda_inc}" if cuda_inc && Dir.exist?(cuda_inc) $LDFLAGS += " -L#{cudnn_lib}" if Dir.exist?(cudnn_lib) && cudnn_lib != cuda_lib with_cuda = have_library("cuda") && have_library("cudnn") end @@ -54,6 +55,24 @@ $INCFLAGS += " -I#{inc}" $INCFLAGS += " -I#{inc}/torch/csrc/api/include" +CONFIG["CC"] = CONFIG["CXX"] +$CFLAGS = $CXXFLAGS + +abort "cuda.h not found" if with_cuda && !find_header("cuda.h") +supports_c10_cuda = with_cuda && try_compile(<<~CPP) + #include + #include + + int main() { + c10::cuda::set_device(0); + return 0; + } +CPP + +if supports_c10_cuda + $defs << " -DHAVE_C10_CUDA" +end + $LDFLAGS += " -Wl,-rpath,#{lib}" if RbConfig::CONFIG["host_os"] =~ /darwin/i && RbConfig::CONFIG["host_cpu"] =~ /arm|aarch64/i && Dir.exist?("/opt/homebrew/opt/libomp/lib") $LDFLAGS += ",-rpath,/opt/homebrew/opt/libomp/lib" diff --git a/ext/torch/tensor.cpp b/ext/torch/tensor.cpp index 390b5a9c..353ae7cd 100644 --- a/ext/torch/tensor.cpp +++ b/ext/torch/tensor.cpp @@ -1,9 +1,13 @@ +#include #include +#include #include #include #include +#include +#include #include "tensor_functions.h" #include "ruby_arg_parser.h" @@ -26,6 +30,114 @@ Array flat_data(Tensor& tensor) { } Rice::Class rb_cTensor; +Rice::Class rb_cHookHandle; + +namespace { + +struct RubyTensorHook { + explicit RubyTensorHook(VALUE proc) : proc_(proc) { + rb_gc_register_address(&proc_); + } + + // The autograd engine can invoke hooks from threads not created by Ruby. + // Register the calling thread with Ruby before acquiring the GVL to avoid + // "rb_thread_call_with_gvl() is called by non-ruby thread" crashes. 
+ static void ensure_ruby_thread_registered() { + // ruby_init_stack is idempotent and safe to call repeatedly; it ensures the + // current native thread is known to the VM before we try to grab the GVL. + volatile VALUE stack_anchor = Qnil; + ruby_init_stack(const_cast(&stack_anchor)); + } + + ~RubyTensorHook() { + rb_gc_unregister_address(&proc_); + } + + at::Tensor call(const at::Tensor& grad) { + ensure_ruby_thread_registered(); + HookCallData data{proc_, grad}; + rb_thread_call_with_gvl(&RubyTensorHook::invoke, &data); + if (data.return_value_defined) { + return data.return_tensor; + } + return grad; + } + + private: + struct HookCallData { + VALUE proc; + at::Tensor grad; + at::Tensor return_tensor; + bool return_value_defined = false; + }; + + static void* invoke(void* arg) { + auto* data = reinterpret_cast(arg); + VALUE grad_obj = Rice::detail::To_Ruby().convert(data->grad); + VALUE result = rb_funcall(data->proc, rb_intern("call"), 1, grad_obj); + if (!NIL_P(result)) { + data->return_tensor = Rice::detail::From_Ruby().convert(result); + data->return_value_defined = true; + } + return nullptr; + } + + VALUE proc_; +}; + +class HookHandle { + public: + HookHandle(const at::Tensor& tensor, unsigned handle, std::shared_ptr hook) + : tensor_(tensor), handle_(handle), hook_(std::move(hook)), removed_(false) {} + + HookHandle(const HookHandle& other) = default; + HookHandle& operator=(const HookHandle& other) = default; + + ~HookHandle() { + remove(); + } + + void remove() { + if (!removed_) { + tensor_.remove_hook(handle_); + removed_ = true; + hook_.reset(); + } + } + + private: + at::Tensor tensor_; + unsigned handle_; + std::shared_ptr hook_; + bool removed_; +}; + +VALUE tensor_register_hook(int argc, VALUE* argv, VALUE self_) { + HANDLE_TH_ERRORS + VALUE callable = Qnil; + rb_scan_args(argc, argv, "01", &callable); + if (NIL_P(callable)) { + if (rb_block_given_p()) { + callable = rb_block_proc(); + } else { + rb_raise(rb_eArgError, "Expected a callable or 
block"); + } + } + if (!rb_respond_to(callable, rb_intern("call"))) { + rb_raise(rb_eArgError, "Hook must respond to call"); + } + + Tensor& self = Rice::detail::From_Ruby().convert(self_); + auto hook = std::make_shared(callable); + unsigned handle = self.register_hook([hook](const at::Tensor& grad) { + return hook->call(grad); + }); + + return Rice::Data_Object(new HookHandle(self, handle, hook), true, rb_cHookHandle); + END_HANDLE_TH_ERRORS +} + +} // namespace std::vector index_vector(Array a) { Object obj; @@ -102,7 +214,17 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions add_tensor_functions(rb_cTensor); THPVariableClass = rb_cTensor.value(); + auto rb_mAutograd = Rice::define_module_under(m, "Autograd"); + rb_cHookHandle = Rice::define_class_under(rb_mAutograd, "RemovableHandle") + .define_method( + "remove", + [](HookHandle& self) { + self.remove(); + return Rice::Nil; + }); + rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1); + rb_define_method(rb_cTensor, "register_hook", (VALUE (*)(...)) tensor_register_hook, -1); rb_cTensor .define_method("cuda?", [](Tensor& self) { return self.is_cuda(); }) diff --git a/ext/torch/torch.cpp b/ext/torch/torch.cpp index 20e4c4d4..ded002af 100644 --- a/ext/torch/torch.cpp +++ b/ext/torch/torch.cpp @@ -1,8 +1,14 @@ #include +#include +#include #include #include #include +#include +#include + +#include #include #include @@ -76,6 +82,30 @@ void init_torch(Rice::Module& m) { input.close(); return torch::pickle_load(bytes); }) + .define_singleton_function( + "_load_with_device", + [](const std::string &filename, const std::string &device_str) { + std::ifstream input(filename, std::ios::binary); + std::vector bytes( + (std::istreambuf_iterator(input)), + (std::istreambuf_iterator())); + input.close(); + + auto device = c10::Device(device_str); + auto reader = std::make_shared( + bytes.data(), + static_cast(bytes.size())); + caffe2::serialize::PyTorchStreamReader 
stream_reader(reader); + + return torch::jit::readArchiveAndTensors( + "data", + /*pickle_prefix=*/"", + /*tensor_prefix=*/"", + /*type_resolver=*/std::nullopt, + /*obj_loader=*/std::nullopt, + /*device=*/device, + stream_reader); + }) .define_singleton_function( "_from_blob", [](Rice::String s, const std::vector &size, const torch::TensorOptions &options) { diff --git a/lib/torch.rb b/lib/torch.rb index 266c2859..667315a5 100644 --- a/lib/torch.rb +++ b/lib/torch.rb @@ -9,6 +9,7 @@ # modules require_relative "torch/device" +require_relative "torch/accelerator" require_relative "torch/inspector" require_relative "torch/tensor" require_relative "torch/version" @@ -399,11 +400,26 @@ def save(obj, f) File.binwrite(f, _save(to_ivalue(obj))) end - def load(filename) + def load(filename, map_location: nil, weights_only: false) # keep backwards compatibility File.open(filename, "rb") { |f| f.read(1) } - to_ruby(_load(filename)) + load_device = map_location_device(map_location) if map_location + result = + if load_device + device_str = + if load_device.respond_to?(:_str) + load_device._str + else + load_device.to_s + end + to_ruby(_load_with_device(filename, device_str)) + else + to_ruby(_load(filename)) + end + ensure_weights_only_contents!(result) if weights_only + result = apply_map_location(result, map_location) if map_location + result end def tensor(data, **options) @@ -536,6 +552,133 @@ def to_ruby(ivalue) end end + WEIGHTS_ONLY_PRIMITIVE_CLASSES = + [ + NilClass, + TrueClass, + FalseClass, + Integer, + Float, + String + ].freeze + + def ensure_weights_only_contents!(obj) + case obj + when *WEIGHTS_ONLY_PRIMITIVE_CLASSES + obj + when Tensor + obj + when Array + obj.each { |value| ensure_weights_only_contents!(value) } + when Hash + obj.each do |key, value| + ensure_weights_only_contents!(key) + ensure_weights_only_contents!(value) + end + else + raise Error, "weights_only load supports tensors, primitive Ruby types, arrays, and hashes (found #{obj.class.name})" + 
end + end + + def map_location_device(map_location) + case map_location + when Device, String, Symbol + normalize_map_location_device(map_location) + when Hash + devices = + map_location.values.map do |value| + normalize_map_location_device(value) + rescue StandardError + nil + end.compact + return nil if devices.empty? + devices.uniq! + devices.one? ? devices.first : nil + else + nil + end + end + + def apply_map_location(obj, map_location) + case obj + when Tensor + map_tensor_location(obj, map_location) + when Array + obj.map { |value| apply_map_location(value, map_location) } + when Hash + obj.each_with_object({}) do |(key, value), memo| + memo[apply_map_location(key, map_location)] = apply_map_location(value, map_location) + end + else + obj + end + end + + def map_tensor_location(tensor, map_location) + case map_location + when nil + tensor + when Hash + target = lookup_map_location_target(map_location, tensor.device) + return tensor if target.nil? + map_tensor_location(tensor, target) + else + return map_tensor_location_callable(tensor, map_location) if map_location.respond_to?(:call) + device = normalize_map_location_device(map_location) + tensor.to(device) + end + end + + def map_tensor_location_callable(tensor, callable) + mapped = callable.call(tensor, map_location_device_tag(tensor.device)) + return tensor if mapped.nil? + unless mapped.is_a?(Tensor) + raise Error, "map_location callable must return a Tensor or nil (got #{mapped.class.name})" + end + mapped + end + + def lookup_map_location_target(mapping, device) + key = map_location_device_tag(device) + mapping.each do |candidate, value| + candidate_key = + case candidate + when Device + map_location_device_tag(candidate) + when String, Symbol + candidate.to_s + else + candidate + end + return value if candidate_key == key + end + nil + end + + def map_location_device_tag(device) + case device + when Device + tag = device.type + tag += ":#{device.index}" unless device.index.nil? 
+        tag
+      when String, Symbol
+        device.to_s
+      else
+        raise Error, "Unknown device reference: #{device.inspect}"
+      end
+    end
+
+    def normalize_map_location_device(location)
+      case location
+      when Device
+        location
+      when String, Symbol
+        device(location.to_s)
+      else
+        raise Error, "Unsupported map_location: #{location.inspect}"
+      end
+    end
+
     def tensor_size(size)
       size.flatten
     end
diff --git a/lib/torch/accelerator.rb b/lib/torch/accelerator.rb
new file mode 100644
index 00000000..abfd95bb
--- /dev/null
+++ b/lib/torch/accelerator.rb
@@ -0,0 +1,20 @@
+module Torch
+  module Accelerator
+    class << self
+      def current_accelerator(check_available: false)
+        device = _current_device
+        return nil unless device
+        return nil if check_available && !available?
+        device
+      end
+
+      def device_count
+        _device_count
+      end
+
+      def available?
+        _is_available
+      end
+    end
+  end
+end
diff --git a/lib/torch/device.rb b/lib/torch/device.rb
index 45a822a8..0621e463 100644
--- a/lib/torch/device.rb
+++ b/lib/torch/device.rb
@@ -8,7 +8,10 @@ def inspect
       extra = ", index: #{index.inspect}" if index?
"device(type: #{type.inspect}#{extra})" end - alias_method :to_s, :inspect + + def to_s + _str + end def ==(other) eql?(other) @@ -22,4 +25,20 @@ def hash [type, index].hash end end + + # String-like wrapper that also exposes device metadata + class DeviceString < String + def initialize(device) + @device = device + super(device._str) + end + + def type + @device.type + end + + def index + @device.index + end + end end diff --git a/lib/torch/distributed.rb b/lib/torch/distributed.rb new file mode 100644 index 00000000..d69e0396 --- /dev/null +++ b/lib/torch/distributed.rb @@ -0,0 +1,454 @@ +require "socket" +require "rbconfig" + +module Torch + module Distributed + DEFAULT_DEVICE_BACKENDS = { + "cpu" => "gloo", + "cuda" => "nccl", + "xpu" => "xccl", + "mps" => "gloo" + }.freeze + + SPAWN_ENV_KEY = "TORCH_DISTRIBUTED_SPAWNED".freeze + SPAWN_RANK_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_RANK".freeze + SPAWN_WORLD_SIZE_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_WORLD_SIZE".freeze + SPAWN_PORT_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_PORT".freeze + SPAWN_PIPE_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_PIPE".freeze + SPAWN_SCRIPT_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_SCRIPT".freeze + SPAWN_TEST_ENV_KEY = "TORCH_DISTRIBUTED_SPAWN_TEST".freeze + SPAWN_ARGV = ARGV.dup.freeze + + class << self + def initialized? + _initialized? + end + + def init_process_group(backend = nil, init_method: "env://", store: nil, rank: nil, world_size: nil, timeout: DEFAULT_TIMEOUT, wait_for_workers: true, device_id: nil) + raise Torch::Error, "torch.distributed is not available" unless available? + + backend ||= default_backend_for(device_id) + + if store.nil? + case init_method + when "env://" + rank = Integer(ENV.fetch("RANK")) if rank.nil? + world_size = Integer(ENV.fetch("WORLD_SIZE")) if world_size.nil? + master_addr = ENV.fetch("MASTER_ADDR", "127.0.0.1") + master_port = Integer(ENV.fetch("MASTER_PORT", "29500")) + raise ArgumentError, "rank is required" if rank.nil? 
+ raise ArgumentError, "world_size is required" if world_size.nil? + is_master = rank.zero? + store = TCPStore.new(master_addr, master_port, world_size, is_master, wait_for_workers: wait_for_workers, timeout: timeout) + else + raise ArgumentError, "store is required when using init_method=#{init_method.inspect}" + end + end + + raise ArgumentError, "rank is required" if rank.nil? + raise ArgumentError, "world_size is required" if world_size.nil? + + device_id ||= default_device_id_for_backend(backend, rank, world_size) + + timeout_ms = (timeout * 1000).to_i + bound_device_id = device_id.nil? ? -1 : Integer(device_id) + if backend == "nccl" && bound_device_id >= 0 && Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:set_device) + device_count = Torch::CUDA.device_count if Torch::CUDA.respond_to?(:device_count) + # Only attempt to switch devices when the requested id exists to avoid + # raising on hosts with fewer GPUs than the provided local rank. + Torch::CUDA.set_device(bound_device_id) if device_count.nil? 
|| bound_device_id < device_count + end + pg = _init_process_group(backend, store, rank, world_size, timeout_ms, bound_device_id) + warmup_process_group(pg, backend) + end + + def destroy_process_group + _destroy_process_group + end + + def default_process_group + _default_process_group + end + + def get_world_size(group = nil) + ensure_process_group!(group) + _get_world_size(group) + end + + def get_rank(group = nil) + ensure_process_group!(group) + _get_rank(group) + end + + def barrier(group: nil) + ensure_process_group!(group) + _barrier(group) + end + + def all_reduce(tensor, op: ReduceOp::SUM, group: nil) + ensure_process_group!(group) + _all_reduce(tensor, op, group) + end + + def broadcast(tensor, src:, group: nil) + ensure_process_group!(group) + _broadcast(tensor, src, group) + end + + def register_ddp_hook(tensor, process_group, world_size) + ensure_process_group!(process_group) + _register_ddp_hook(tensor, process_group, Integer(world_size)) + rescue NoMethodError + # Fallback for environments built without the native helper; this may + # still call back into Ruby from autograd threads. + tensor.register_hook do |grad| + all_reduce(grad, group: process_group) + grad.div!(world_size.to_f) + end + end + + def get_default_backend_for_device(device) + backend = DEFAULT_DEVICE_BACKENDS[device_type_from(device)] + raise ArgumentError, "Default backend not registered for device: #{device.inspect}" unless backend + backend + end + + def fork_world(world_size, host: "127.0.0.1", start_method: :fork, &block) + raise ArgumentError, "world_size must be positive" unless world_size.to_i.positive? + raise ArgumentError, "block required" unless block + + start_method = normalize_start_method(start_method) + return run_spawn_worker(&block) if start_method == :spawn && spawn_worker? 
+ + fork_spawn_world(world_size, host: host, start_method: start_method, &block) + end + + def fork_spawn_world(world_size, host:, start_method:, &block) + port = free_port(host: host) + readers = [] + pids = [] + pgid = nil + completed = false + + begin + world_size.times do |rank| + reader, writer = IO.pipe + begin + case start_method + when :fork + pids << fork_worker(reader, writer, rank, port, world_size, &block) + when :spawn + pid, pgid = spawn_worker(reader, writer, rank, port, host: host, world_size: world_size, pgid: pgid) + pids << pid + else + raise ArgumentError, "Unsupported start_method: #{start_method.inspect}" + end + readers << reader + writer.close unless writer.closed? + rescue Exception + reader.close unless reader.closed? + writer.close unless writer.closed? + raise + end + end + + read_failure = Object.new + + outputs = readers.map do |reader| + begin + Marshal.load(reader) + rescue EOFError + read_failure + ensure + reader.close unless reader.closed? + end + end + + statuses = pids.each_with_index.map do |pid, idx| + _pid, status = Process.wait2(pid) + [idx, pid, status] + end + + statuses.each do |idx, pid, status| + output = outputs[idx] + if output.equal?(read_failure) + raise Torch::Error, "Child #{pid} closed pipe before sending result (status #{status.exitstatus})" + end + if !status.success? || (output.is_a?(Hash) && output[:error]) + message = if output.is_a?(Hash) && output[:error] + "Child #{pid} failed: #{output[:error]}\n#{Array(output[:backtrace]).join("\n")}" + else + "Child #{pid} exited with status #{status.exitstatus}" + end + raise Torch::Error, message + end + end + + completed = true + outputs + ensure + # Ensure child workers are cleaned up if an interrupt or error occurs. 
+ terminate_processes(pids, pgid: pgid) unless completed + end + end + + def free_port(host: "127.0.0.1") + server = TCPServer.new(host, 0) + port = server.addr[1] + server.close + port + end + + private + + def ensure_process_group!(group) + return if group || initialized? + + raise Torch::Error, "Default process group is not initialized" + end + + def default_device_id_for_backend(backend, rank, world_size) + return unless backend == "nccl" + + default_local_rank(rank, world_size) + end + + def warmup_process_group(pg, backend) + return pg unless backend == "nccl" + + # Only warm up when a native process group was returned. + # Test helpers may stub out `_init_process_group` and return arbitrary + # Ruby objects, which cannot be passed to the C++ bindings. + return pg unless pg.nil? || (defined?(Torch::Distributed::ProcessGroup) && pg.is_a?(Torch::Distributed::ProcessGroup)) + + # Prime NCCL communicators so the first user-visible collective is fast + _barrier(pg) + pg + rescue + _destroy_process_group + raise + end + + def default_local_rank(rank, world_size) + local_rank = env_integer("LOCAL_RANK") + return local_rank unless local_rank.nil? + + local_world_size = env_integer("LOCAL_WORLD_SIZE") || world_size + return unless local_world_size && rank + + rank % local_world_size if local_world_size.positive? + end + + def env_integer(key) + Integer(ENV[key]) if ENV.key?(key) + rescue ArgumentError + nil + end + + def default_backend_for(device_id) + get_default_backend_for_device(device_id) + end + + def device_type_from(device) + case device + when Torch::Device + device.type + when NilClass + accelerator_type || "cpu" + when String + Torch.device(device).type + when Integer + return accelerator_type || "cpu" if device.negative? 
+ if Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:device_count) + max = Torch::CUDA.device_count + return accelerator_type || "cpu" if max <= 0 || device >= max + return Torch.device("cuda:#{device}").type + end + accelerator_type || "cpu" + else + return device.type if device.respond_to?(:type) + Torch.device(device).type + end + rescue => e + raise ArgumentError, "Invalid device #{device.inspect}: #{e.message}" + end + + def accelerator_type + acc = Torch::Accelerator.current_accelerator + acc.type if acc && acc.respond_to?(:type) + rescue + nil + end + + def normalize_start_method(start_method) + method = start_method&.to_sym + return method if [:fork, :spawn].include?(method) + + raise ArgumentError, "start_method must be :fork or :spawn (got #{start_method.inspect})" + end + + def spawn_worker? + ENV[SPAWN_ENV_KEY] == "1" + end + + def run_spawn_worker(&block) + rank = Integer(ENV.fetch(SPAWN_RANK_ENV_KEY)) + port = Integer(ENV.fetch(SPAWN_PORT_ENV_KEY)) + pipe_fd = Integer(ENV.fetch(SPAWN_PIPE_ENV_KEY)) + + writer = IO.new(pipe_fd, "wb") + writer.binmode + writer.sync = true + + result = block.call(rank, port) + Marshal.dump(result, writer) + writer.flush + writer.close + Process.exit!(0) + rescue Exception => e + begin + if defined?(writer) && writer && !writer.closed? 
+ Marshal.dump({error: "#{e.class}: #{e.message}", backtrace: e.backtrace}, writer) + writer.flush + writer.close + end + rescue StandardError + # best-effort error reporting back to parent + ensure + Process.exit!(1) + end + end + + def fork_worker(reader, writer, rank, port, world_size, &block) + fork do + reader.close + begin + ENV["LOCAL_RANK"] = rank.to_s + ENV["LOCAL_WORLD_SIZE"] = world_size.to_s + ENV["RANK"] = rank.to_s + ENV["WORLD_SIZE"] = world_size.to_s + writer.binmode + writer.sync = true + result = block.call(rank, port) + Marshal.dump(result, writer) + writer.flush + writer.close + Process.exit!(0) + rescue => e + Marshal.dump({error: "#{e.class}: #{e.message}", backtrace: e.backtrace}, writer) + writer.flush + writer.close + Process.exit!(1) + ensure + writer.close unless writer.closed? + end + end + end + + def spawn_worker(reader, writer, rank, port, host:, world_size:, pgid: nil) + writer.binmode + writer.close_on_exec = false + + script = ENV[SPAWN_SCRIPT_ENV_KEY] || $0 + env = { + SPAWN_ENV_KEY => "1", + SPAWN_RANK_ENV_KEY => rank.to_s, + SPAWN_WORLD_SIZE_ENV_KEY => world_size.to_s, + SPAWN_PORT_ENV_KEY => port.to_s, + SPAWN_PIPE_ENV_KEY => writer.fileno.to_s, + "LOCAL_RANK" => rank.to_s, + "LOCAL_WORLD_SIZE" => world_size.to_s, + "MASTER_ADDR" => host, + "MASTER_PORT" => port.to_s, + "RANK" => rank.to_s, + "WORLD_SIZE" => world_size.to_s + } + env["RUBYLIB"] = [ENV["RUBYLIB"], $LOAD_PATH.join(File::PATH_SEPARATOR)].compact.reject(&:empty?).join(File::PATH_SEPARATOR) + + spawn_opts = {close_others: false} + spawn_opts[:pgroup] = pgid ? 
pgid : true + + pid = Process.spawn(env, RbConfig.ruby, script, *spawn_argv, spawn_opts) + pgid ||= pid + [pid, pgid] + rescue SystemCallError => e + raise Torch::Error, "failed to spawn worker #{rank}: #{e.message}" + end + + def spawn_argv + test_filter = ENV[SPAWN_TEST_ENV_KEY] + return SPAWN_ARGV unless test_filter + return SPAWN_ARGV if SPAWN_ARGV.include?("-n") + + # Restrict child to the specific test that triggered the spawn + SPAWN_ARGV + ["-n", test_filter] + end + + def terminate_processes(pids, pgid: nil) + return if pids.empty? && !pgid + + send_process_group_signal(pgid, "TERM") + pids.each { |pid| safe_kill(pid, "TERM") } + sleep(0.2) + pids.each do |pid| + next unless process_alive?(pid) + + safe_kill(pid, "KILL") + end + pids.each do |pid| + begin + Process.wait(pid) + rescue Errno::ECHILD + end + end + end + + def send_process_group_signal(pgid, sig) + return unless pgid + + Process.kill(sig, -pgid) + rescue Errno::ESRCH + end + + def safe_kill(pid, sig) + Process.kill(sig, pid) + rescue Errno::ESRCH + end + + def process_alive?(pid) + Process.kill(0, pid) + true + rescue Errno::ESRCH + false + end + end + + class TCPStore + def self.new(host, port, world_size, is_master, wait_for_workers: true, timeout: DEFAULT_TIMEOUT) + Torch::Distributed._create_tcp_store(host, port, world_size, is_master, (timeout * 1000).to_i, wait_for_workers) + end + end + + class FileStore + def self.new(path, world_size) + Torch::Distributed._create_file_store(path, world_size) + end + end + + if respond_to?(:_create_hash_store) + class HashStore + def self.new + Torch::Distributed._create_hash_store + end + end + end + end +end + +at_exit do + begin + Torch::Distributed.destroy_process_group if Torch::Distributed.available? && Torch::Distributed.initialized? 
+ rescue Exception + # best-effort cleanup to avoid leaked process groups + end +end diff --git a/lib/torch/nn/module_list.rb b/lib/torch/nn/module_list.rb index 925bab5b..02c17575 100644 --- a/lib/torch/nn/module_list.rb +++ b/lib/torch/nn/module_list.rb @@ -6,7 +6,7 @@ class ModuleList < Module def initialize(mods = nil) super() - self.concat(mods) if mods + concat(mods) if mods end def length @@ -31,6 +31,10 @@ def each(&block) end end + def map(&block) + @modules.values.map(&block) + end + def append(mod) raise ArgumentError, "Provided element is not a module" unless mod.is_a?(Module) add_module(length.to_s, mod) diff --git a/lib/torch/nn/parallel/distributed_data_parallel.rb b/lib/torch/nn/parallel/distributed_data_parallel.rb new file mode 100644 index 00000000..dd5e0245 --- /dev/null +++ b/lib/torch/nn/parallel/distributed_data_parallel.rb @@ -0,0 +1,115 @@ +module Torch + module NN + module Parallel + class DistributedDataParallel < Module + attr_reader :module, :process_group + + def initialize(mod, device_ids: nil, process_group: nil, broadcast_buffers: true) + super() + raise Torch::Error, "torch.distributed is not available" unless Torch::Distributed.available? 
+ + @module = mod + @broadcast_buffers = broadcast_buffers + @process_group = process_group || Torch::Distributed.default_process_group + raise Torch::Error, "Process group must be initialized before using DistributedDataParallel" unless @process_group + + @world_size = Torch::Distributed.get_world_size(@process_group) + @rank = Torch::Distributed.get_rank(@process_group) + @device = normalize_device(Array(device_ids).compact.first) + move_to_device(@device) if @device + + synchronize_parameters + @hook_handles = register_parameter_hooks + end + + def forward(*inputs, **kwargs) + outputs = @module.call(*move_inputs(inputs), **move_kwargs(kwargs)) + broadcast_buffers_if_needed + outputs + end + + alias_method :call, :forward + + def train(mode = true) + @module.train(mode) + broadcast_buffers_if_needed + self + end + + private + + def normalize_device(device) + return nil unless device + return device if device.is_a?(Torch::Device) + + if device.is_a?(Integer) + if Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:available?) && Torch::CUDA.available? 
+ return Torch.device("cuda:#{device}") + end + end + + Torch.device(device) + end + + def move_to_device(device) + return unless device + + @module.to(device) + end + + def move_inputs(inputs) + return inputs unless @device + + inputs.map { |value| move_value(value, @device) } + end + + def move_kwargs(kwargs) + return kwargs unless @device + + kwargs.transform_values { |value| move_value(value, @device) } + end + + def move_value(value, device) + case value + when Torch::Tensor + value.to(device) + when Array + value.map { |v| move_value(v, device) } + when Hash + value.transform_values { |v| move_value(v, device) } + else + value + end + end + + def synchronize_parameters + Torch::Distributed.barrier(group: @process_group) + Torch.no_grad do + @module.parameters.each do |param| + Torch::Distributed.broadcast(param, src: 0, group: @process_group) + end + broadcast_buffers_if_needed + end + end + + def broadcast_buffers_if_needed + return unless @broadcast_buffers + + Torch.no_grad do + @module.buffers.each do |buffer| + Torch::Distributed.broadcast(buffer, src: 0, group: @process_group) + end + end + end + + def register_parameter_hooks + @module.parameters.filter_map do |param| + next unless param.requires_grad? 
+ + Torch::Distributed.register_ddp_hook(param, @process_group, @world_size) + end + end + end + end + end +end diff --git a/lib/torch/tensor.rb b/lib/torch/tensor.rb index ed8ab71e..cfc4b63c 100644 --- a/lib/torch/tensor.rb +++ b/lib/torch/tensor.rb @@ -115,7 +115,8 @@ def item if numel != 1 raise Error, "only one element tensors can be converted to Ruby scalars" end - to_a.first + # use flatten to handle tensors with a single element but multiple dimensions + to_a.flatten.first end def to_i @@ -211,9 +212,8 @@ def coerce(other) end end - # TODO return Device instead of String in 0.19.0 def device - _device._str + DeviceString.new(_device) end end end diff --git a/lib/torch/torchrun.rb b/lib/torch/torchrun.rb new file mode 100644 index 00000000..9155b892 --- /dev/null +++ b/lib/torch/torchrun.rb @@ -0,0 +1,530 @@ +# frozen_string_literal: true + +require "optparse" +require "socket" +require "etc" +require "securerandom" +require "rbconfig" + +require_relative "../torch" + +module Torch + module TorchRun + SIGNALS = %w[INT TERM QUIT].freeze + + class Error < StandardError; end + + class Parser + attr_reader :parser + + def initialize + @parser = OptionParser.new + end + + def parse(argv) + options = default_options + + parser.banner = "Usage: torchrun [options] TRAINING_SCRIPT [script args]" + parser.separator "" + parser.separator "Launch parameters:" + + parser.on("--nnodes MIN[:MAX]", String, "Number of nodes or range (default: #{options[:nnodes]})") do |value| + options[:nnodes] = value + end + + parser.on("--nproc-per-node VALUE", String, "Processes per node (int, gpu, cpu, auto). Default: #{options[:nproc_per_node]}") do |value| + options[:nproc_per_node] = value + end + + parser.on("--node-rank VALUE", Integer, "Rank of the node for multi-node jobs. Default: #{options[:node_rank]}") do |value| + options[:node_rank] = value + end + + parser.on("--rdzv-backend NAME", String, "Rendezvous backend (static or c10d). 
Default: #{options[:rdzv_backend]}") do |value| + options[:rdzv_backend] = value + end + + parser.on("--rdzv-endpoint HOST[:PORT]", String, "Rendezvous endpoint. Default: use --master-addr/--master-port") do |value| + options[:rdzv_endpoint] = value + end + + parser.on("--rdzv-id ID", String, "User defined job id. Default: #{options[:rdzv_id]}") do |value| + options[:rdzv_id] = value + end + + parser.on("--rdzv-conf CONF", String, "Additional rendezvous config (k=v,k2=v2)") do |value| + options[:rdzv_conf] = parse_kv_pairs(value) + end + + parser.on("--standalone", "Start a local rendezvous store on a free port") do + options[:standalone] = true + end + + parser.on("--max-restarts VALUE", Integer, "Restarts before failing. Default: #{options[:max_restarts]}") do |value| + options[:max_restarts] = value + end + + parser.on("--monitor-interval SECONDS", Float, "Delay between restart attempts. Default: #{options[:monitor_interval]}") do |value| + options[:monitor_interval] = value + end + + parser.on("--role NAME", String, "Role for the worker group. Default: #{options[:role]}") do |value| + options[:role] = value + end + + parser.on("--master-addr HOST", String, "Master address for static rendezvous. Default: #{options[:master_addr]}") do |value| + options[:master_addr] = value + end + + parser.on("--master-port PORT", Integer, "Master port for static rendezvous. Default: #{options[:master_port]}") do |value| + options[:master_port] = value + end + + parser.on("--pass-local-rank-arg", "Append --local-rank to the training script invocation") do + options[:pass_local_rank_arg] = true + end + + parser.on("--no-ruby", "Execute the training script directly instead of `#{RbConfig.ruby}`") do + options[:no_ruby] = true + end + + parser.on("-h", "--help", "Prints this help") do + puts parser + exit + end + + rest = parser.order!(argv) + raise OptionParser::MissingArgument, "training_script" if rest.empty?
+ + training_script = rest.shift + [options, training_script, rest] + end + + def to_s + parser.to_s + end + + private + + def default_options + { + nnodes: "1:1", + nproc_per_node: "1", + node_rank: 0, + rdzv_backend: "static", + rdzv_endpoint: "", + rdzv_id: "none", + rdzv_conf: {}, + standalone: false, + max_restarts: 0, + monitor_interval: 1.0, + role: "default", + master_addr: "127.0.0.1", + master_port: 29_500, + pass_local_rank_arg: false, + no_ruby: false + } + end + + def parse_kv_pairs(value) + return {} if value.nil? || value.strip.empty? + + value.split(",").each_with_object({}) do |pair, acc| + key, val = pair.split("=", 2) + raise OptionParser::InvalidArgument, "Invalid rendezvous config entry: #{pair.inspect}" unless key && val + + acc[key.strip] = val.strip + end + end + end + + module_function + + def start(argv, out: $stdout, err: $stderr) + parser = Parser.new + options, script, script_args = parser.parse(argv) + status = Launcher.new(options, script, script_args, out: out, err: err).run + exit(status) + rescue OptionParser::ParseError => e + err.puts(e.message) + err.puts(parser) + exit(2) + rescue Error => e + err.puts("torchrun: #{e.message}") + exit(1) + end + + class Launcher + def initialize(options, script, script_args, out: $stdout, err: $stderr) + @options = options + @script = script + @script_args = script_args + @out = out + @err = err + + @local_world_size = determine_local_world_size(@options[:nproc_per_node]) + @min_nodes, @max_nodes = parse_nnodes(@options[:nnodes]) + @num_nodes = ensure_fixed_nnodes(@min_nodes, @max_nodes) + @node_rank = @options[:node_rank] + @max_restarts = [@options[:max_restarts], 0].max + @monitor_interval = [@options[:monitor_interval], 0.0].max + @role = @options[:role] + @pass_local_rank_arg = @options[:pass_local_rank_arg] + @no_ruby = @options[:no_ruby] + validate_node_rank! + + setup_rendezvous! 
+ end + + def run + restarts = 0 + + loop do + status = launch_worker_group(restarts) + return status if status.zero? || @signal_received + return status if restarts >= @max_restarts + + restarts += 1 + log("Worker group failed (exit #{status}). Restarting #{restarts}/#{@max_restarts} ...") + sleep(@monitor_interval) if @monitor_interval.positive? + end + end + + private + + def launch_worker_group(restart_count) + @signal_received = nil + @current_pids = spawn_workers(restart_count) + handler_state = setup_signal_handlers + status = monitor_workers(@current_pids.dup) + cleanup_workers(@current_pids) + restore_signal_handlers(handler_state) + return signal_exit_status if @signal_received + + status + ensure + @worker_pgid = nil + @current_pids = [] + end + + def spawn_workers(restart_count) + base_env = base_environment(restart_count) + pgid = nil + workers = Array.new(@local_world_size) do |local_rank| + env = base_env.merge(rank_environment(local_rank)) + pid, pgid = spawn_worker(env, local_rank, pgid) + pid + end + @worker_pgid = pgid + workers + end + + def spawn_worker(env, local_rank, pgid) + args = command_arguments(local_rank) + spawn_opts = pgid ? 
{ pgroup: pgid } : { pgroup: true } + pid = Process.spawn(env, *args, spawn_opts) + pgid ||= pid + [pid, pgid] + rescue SystemCallError => e + raise Error, "failed to launch worker #{local_rank}: #{e.message}" + end + + def command_arguments(local_rank) + cmd = [] + if @no_ruby + cmd << @script + else + cmd << RbConfig.ruby + cmd << @script + end + cmd.concat(@script_args) + cmd << "--local-rank=#{local_rank}" if @pass_local_rank_arg + cmd + end + + def base_environment(restart_count) + endpoint = "#{@master_addr}:#{@master_port}" + env = { + "MASTER_ADDR" => @master_addr, + "MASTER_PORT" => @master_port.to_s, + "WORLD_SIZE" => world_size.to_s, + "LOCAL_WORLD_SIZE" => @local_world_size.to_s, + "GROUP_RANK" => @node_rank.to_s, + "TORCHRUN_ROLE" => @role, + "TORCHRUN_NNODES" => @num_nodes.to_s, + "TORCHRUN_NPROC_PER_NODE" => @local_world_size.to_s, + "TORCHELASTIC_RUN_ID" => @rdzv_id, + "TORCHRUN_RDZV_BACKEND" => @rdzv_backend, + "TORCHRUN_RDZV_ENDPOINT" => endpoint, + "TORCHELASTIC_RESTART_COUNT" => restart_count.to_s, + "TORCHRUN_STANDALONE" => @standalone ? "1" : "0" + } + unless @rdzv_conf.empty? + env["TORCHRUN_RDZV_CONF"] = @rdzv_conf.map { |k, v| "#{k}=#{v}" }.join(",") + end + ENV.to_h.merge(env) + end + + def rank_environment(local_rank) + rank = @node_rank * @local_world_size + local_rank + { + "LOCAL_RANK" => local_rank.to_s, + "RANK" => rank.to_s + } + end + + def monitor_workers(pids) + exit_code = 0 + remaining = pids.dup + until remaining.empty? + pid, status = Process.wait2 + next unless pid + + remaining.delete(pid) + unless status.success? + exit_code = exit_status_from(status) + terminate_workers(remaining) + break + end + end + exit_code + rescue Errno::ECHILD + 0 + end + + def terminate_workers(pids) + return if pids.empty? 
+ + send_process_group_signal("TERM") + pids.each { |pid| send_signal(pid, "TERM") } + sleep(0.2) + pids.each do |pid| + next unless process_alive?(pid) + + send_signal(pid, "KILL") + end + pids.each do |pid| + begin + Process.wait(pid) + rescue Errno::ECHILD + end + end + end + + def process_alive?(pid) + Process.kill(0, pid) + true + rescue Errno::ESRCH + false + end + + def setup_signal_handlers + SIGNALS.each_with_object({}) do |sig, acc| + next unless Signal.list.key?(sig) + + previous = Signal.trap(sig) do + @signal_received = sig + forward_signal(sig) + end + acc[sig] = previous + end + end + + def forward_signal(sig) + send_process_group_signal(sig) + (@current_pids || []).each { |pid| send_signal(pid, sig) } + end + + def restore_signal_handlers(state) + return unless state + + state.each do |sig, previous| + Signal.trap(sig, previous) + end + end + + def send_signal(pid, sig) + Process.kill(sig, pid) + rescue Errno::ESRCH + nil + end + + def send_process_group_signal(sig) + return unless @worker_pgid + + Process.kill(sig, -@worker_pgid) + rescue Errno::ESRCH + nil + end + + def cleanup_workers(pids) + pids.each do |pid| + next unless process_alive?(pid) + + begin + Process.wait(pid) + rescue Errno::ECHILD + end + end + end + + def signal_exit_status + return 0 unless @signal_received + + 128 + Signal.list.fetch(@signal_received, 0) + end + + def exit_status_from(status) + if status.exited? + status.exitstatus + elsif status.signaled? + 128 + status.termsig + else + 1 + end + end + + def determine_local_world_size(value) + spec = value.to_s.strip.downcase + case spec + when "", "1" + 1 + when /\A\d+\z/ + amount = spec.to_i + raise Error, "nproc-per-node must be >= 1" if amount < 1 + + amount + when "gpu" + gpu_count = cuda_device_count + raise Error, "CUDA is not available for --nproc-per-node=gpu" if gpu_count.zero? + + gpu_count + when "auto" + gpu_count = cuda_device_count + return gpu_count if gpu_count.positive? 
+ + cpu_count + when "cpu" + cpu_count + else + raise Error, "Unsupported --nproc-per-node value: #{value}" + end + end + + def cuda_device_count + return 0 unless defined?(Torch::CUDA) + return 0 unless Torch::CUDA.respond_to?(:available?) && Torch::CUDA.available? + return 0 unless Torch::CUDA.respond_to?(:device_count) + + Torch::CUDA.device_count + rescue StandardError + 0 + end + + def cpu_count + Etc.respond_to?(:nprocessors) ? (Etc.nprocessors || 1) : 1 + rescue StandardError + 1 + end + + def parse_nnodes(value) + parts = value.split(":") + nums = parts.map do |part| + Integer(part, exception: false) + end + raise Error, "Invalid --nnodes value: #{value.inspect}" if nums.any?(&:nil?) + + if nums.length == 1 + [nums.first, nums.first] + elsif nums.length == 2 + [nums.first, nums.last] + else + raise Error, "Invalid --nnodes value: #{value.inspect}" + end + end + + def ensure_fixed_nnodes(min_nodes, max_nodes) + raise Error, "--nnodes minimum must be >= 1" if min_nodes < 1 + raise Error, "--nnodes maximum must be >= minimum" if max_nodes < min_nodes + raise Error, "Elastic nnodes ranges are not supported yet (got #{min_nodes}:#{max_nodes})" if min_nodes != max_nodes + + min_nodes + end + + def world_size + @world_size ||= @num_nodes * @local_world_size + end + + def validate_node_rank! + raise Error, "--node-rank must be >= 0" if @node_rank.negative? + raise Error, "--node-rank (#{@node_rank}) must be less than --nnodes (#{@num_nodes})" if @node_rank >= @num_nodes + end + + def setup_rendezvous! 
+ @rdzv_backend = normalize_backend(@options[:rdzv_backend]) + @rdzv_conf = @options[:rdzv_conf] || {} + if @options[:standalone] + configure_standalone_rendezvous + else + configure_static_rendezvous + end + end + + def normalize_backend(value) + backend = value.to_s.downcase + raise Error, "Unsupported rendezvous backend: #{value.inspect}" unless %w[static c10d].include?(backend) + + backend + end + + def configure_standalone_rendezvous + @standalone = true + @rdzv_backend = "c10d" + @rdzv_id = SecureRandom.uuid + @master_addr = "127.0.0.1" + @master_port = find_free_port(@master_addr) + log(<<~MSG) + + ************************************** + Rendezvous info: + --rdzv-backend=#{@rdzv_backend} + --rdzv-endpoint=#{@master_addr}:#{@master_port} + --rdzv-id=#{@rdzv_id} + ************************************** + + MSG + end + + def configure_static_rendezvous + @standalone = false + endpoint_host, endpoint_port = parse_endpoint(@options[:rdzv_endpoint]) + @master_addr = endpoint_host || @options[:master_addr] + @master_port = endpoint_port || @options[:master_port] + @rdzv_id = @options[:rdzv_id] + raise Error, "MASTER_ADDR must be provided" if @master_addr.to_s.empty? + raise Error, "MASTER_PORT must be > 0" unless @master_port.to_i.positive? + end + + def parse_endpoint(value) + return [nil, nil] if value.nil? || value.strip.empty? + + host, port_str = value.split(":", 2) + port = port_str ? Integer(port_str, exception: false) : nil + raise Error, "Invalid rendezvous endpoint: #{value.inspect}" if host.to_s.empty? || (port_str && port.nil?) 
+ + [host, port] + end + + def find_free_port(host) + server = TCPServer.new(host, 0) + server.addr[1] + ensure + server&.close + end + + def log(message) + @out.puts(message) + end + end + end +end diff --git a/test/device_test.rb b/test/device_test.rb index 69f778f6..b31b3348 100644 --- a/test/device_test.rb +++ b/test/device_test.rb @@ -22,4 +22,9 @@ def test_inspect assert_equal %!device(type: "cpu")!, Torch.device("cpu").inspect assert_equal %!device(type: "cpu", index: 0)!, Torch.device("cpu:0").inspect end + + def test_to_s + assert_equal "cpu", Torch.device("cpu").to_s + assert_equal "cpu:0", Torch.device("cpu:0").to_s + end end diff --git a/test/save_test.rb b/test/save_test.rb index a7438e03..640fdf25 100644 --- a/test/save_test.rb +++ b/test/save_test.rb @@ -55,6 +55,61 @@ def test_load_missing assert_equal "No such file or directory @ rb_sysopen - missing.bin", error.message end + def test_load_with_map_location_string + tmpfile = Tempfile.new + tensor = Torch.tensor([1, 2, 3]) + Torch.save(tensor, tmpfile.path) + loaded = Torch.load(tmpfile.path, map_location: "cpu") + assert_equal tensor.to_a, loaded.to_a + end + + def test_load_with_map_location_callable + tmpfile = Tempfile.new + tensor = Torch.tensor([1, 2, 3]) + Torch.save(tensor, tmpfile.path) + seen = [] + loaded = Torch.load(tmpfile.path, map_location: lambda { |value, loc| + seen << loc + value + }) + assert_equal tensor.to_a, loaded.to_a + assert_equal ["cpu"], seen + end + + def test_load_with_weights_only + tmpfile = Tempfile.new + tensor = Torch.tensor([1, 2, 3]) + Torch.save(tensor, tmpfile.path) + loaded = Torch.load(tmpfile.path, weights_only: true) + assert_equal tensor.to_a, loaded.to_a + end + + def test_load_map_location_cuda_to_cpu + skip "Requires CUDA" unless Torch::CUDA.available? 
+ + tmpfile = Tempfile.new + tensor = Torch.tensor([1, 2, 3]).cuda + Torch.save(tensor, tmpfile.path) + + loaded = Torch.load(tmpfile.path, map_location: "cpu") + assert_equal "cpu", loaded.device.type + assert_equal tensor.cpu.to_a, loaded.to_a + end + + def test_load_map_location_cpu_to_cuda + skip "Requires CUDA" unless Torch::CUDA.available? + + tmpfile = Tempfile.new + tensor = Torch.tensor([1, 2, 3]) + Torch.save(tensor, tmpfile.path) + + device = "cuda:0" + loaded = Torch.load(tmpfile.path, map_location: device) + assert_equal "cuda", loaded.device.type + assert_equal 0, loaded.device.index + assert_equal tensor.to_a, loaded.cpu.to_a + end + private def assert_save(obj) diff --git a/test/support/scripts/show_ranks.rb b/test/support/scripts/show_ranks.rb new file mode 100644 index 00000000..6654dfcb --- /dev/null +++ b/test/support/scripts/show_ranks.rb @@ -0,0 +1,7 @@ +# frozen_string_literal: true + +$stdout.sync = true +rank = ENV.fetch("RANK", "unknown") +local_rank = ENV.fetch("LOCAL_RANK", "unknown") +world_size = ENV.fetch("WORLD_SIZE", "unknown") +puts "RANK=#{rank} LOCAL_RANK=#{local_rank} WORLD_SIZE=#{world_size}" diff --git a/test/torchrun_test.rb b/test/torchrun_test.rb new file mode 100644 index 00000000..a3cf7a38 --- /dev/null +++ b/test/torchrun_test.rb @@ -0,0 +1,33 @@ +# frozen_string_literal: true + +require_relative "test_helper" + +require "open3" +require "rbconfig" + +class TorchRunTest < Minitest::Test + def test_standalone_launches_multiple_workers + script = File.expand_path("support/scripts/show_ranks.rb", __dir__) + torchrun = File.expand_path("../bin/torchrun", __dir__) + stdout, stderr, status = Open3.capture3( + {"TORCHRUN_TEST" => "1"}, + RbConfig.ruby, + torchrun, + "--standalone", + "--nproc-per-node=2", + script + ) + + assert status.success?, "torchrun failed: #{stderr}" + + lines = stdout.lines.map(&:strip).select { |line| line.start_with?("RANK=") } + assert_equal 2, lines.size, "expected two worker outputs, got: 
#{lines.inspect}" + ranks = lines.map do |line| + match = line.match(/RANK=(\d+)\s+LOCAL_RANK=(\d+)\s+WORLD_SIZE=(\d+)/) + raise "unexpected output: #{line}" unless match + + [match[1].to_i, match[2].to_i, match[3].to_i] + end + assert_equal [[0, 0, 2], [1, 1, 2]], ranks.sort + end +end