diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 83990340..122631ba 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -52,7 +52,7 @@ def _maybe_aqt_einsum(quant: Quant):
 class AttentionOp(nn.Module):
   mesh: Mesh
   attention_kernel: str
-  scale: int
+  scale: float
   heads: int
   dim_head: int
   use_memory_efficient_attention: bool = False
@@ -60,9 +60,9 @@ class AttentionOp(nn.Module):
   float32_qk_product: bool = True
   flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
   flash_min_seq_length: int = 4096
-  flash_block_sizes: BlockSizes = None
+  flash_block_sizes: BlockSizes | None = None
   dtype: DType = jnp.float32
-  quant: Quant = None
+  quant: Quant | None = None

   def setup(self):
     if self.attention_kernel == "cudnn_flash_te":
@@ -79,7 +79,7 @@ def setup(self):
           dtype=self.dtype,
           # float32_logits=self.float32_logits,
           qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
-          scale_factor=self.scale,
+          scale_factor=float(self.scale),
           transpose_batch_sequence=False,
       )

@@ -415,15 +415,15 @@ class FlaxFluxAttention(nn.Module):
   split_head_dim: bool = False
   attention_kernel: str = "dot_product"
   flash_min_seq_length: int = 4096
-  flash_block_sizes: BlockSizes = None
-  mesh: jax.sharding.Mesh = None
+  flash_block_sizes: BlockSizes | None = None
+  mesh: jax.sharding.Mesh | None = None
   dtype: jnp.dtype = jnp.float32
   weights_dtype: jnp.dtype = jnp.float32
   query_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   key_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   out_axis_names: AxisNames = (BATCH, LENGTH, EMBED)
-  precision: jax.lax.Precision = None
+  precision: jax.lax.Precision | None = None
   qkv_bias: bool = False

   def setup(self):
@@ -619,16 +619,16 @@ class FlaxAttention(nn.Module):
   split_head_dim: bool = False
   attention_kernel: str = "dot_product"
   flash_min_seq_length: int = 4096
-  flash_block_sizes: BlockSizes = None
-  mesh: jax.sharding.Mesh = None
+  flash_block_sizes: BlockSizes | None = None
+  mesh: jax.sharding.Mesh | None = None
   dtype: jnp.dtype = jnp.float32
   weights_dtype: jnp.dtype = jnp.float32
   query_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   key_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   out_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
-  precision: jax.lax.Precision = None
-  quant: Quant = None
+  precision: jax.lax.Precision | None = None
+  quant: Quant | None = None

   def setup(self):

@@ -762,10 +762,10 @@ class FlaxBasicTransformerBlock(nn.Module):
   split_head_dim: bool = False
   attention_kernel: str = "dot_product"
   flash_min_seq_length: int = 4096
-  flash_block_sizes: BlockSizes = None
-  mesh: jax.sharding.Mesh = None
-  precision: jax.lax.Precision = None
-  quant: Quant = None
+  flash_block_sizes: BlockSizes | None = None
+  mesh: jax.sharding.Mesh | None = None
+  precision: jax.lax.Precision | None = None
+  quant: Quant | None = None

   def setup(self):
     # self attention (or cross_attention if only_cross_attention is True)
@@ -890,12 +890,12 @@ class FlaxTransformer2DModel(nn.Module):
   split_head_dim: bool = False
   attention_kernel: str = "dot_product"
   flash_min_seq_length: int = 4096
-  flash_block_sizes: BlockSizes = None
-  mesh: jax.sharding.Mesh = None
+  flash_block_sizes: BlockSizes | None = None
+  mesh: jax.sharding.Mesh | None = None
   norm_num_groups: int = 32
-  precision: jax.lax.Precision = None
+  precision: jax.lax.Precision | None = None
   hidden_state_axis_names: AxisNames = (BATCH, LENGTH, D_KV)
-  quant: Quant = (None,)
+  quant: Quant | tuple[None] = (None,)

   def setup(self):
     self.norm = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-5, dtype=self.dtype, param_dtype=self.weights_dtype)
@@ -1019,7 +1019,7 @@ class FlaxFeedForward(nn.Module):
   dropout: float = 0.0
   dtype: jnp.dtype = jnp.float32
   weights_dtype: jnp.dtype = jnp.float32
-  precision: jax.lax.Precision = None
+  precision: jax.lax.Precision | None = None

   def setup(self):
     # The second linear layer needs to be called
@@ -1051,7 +1051,7 @@ class FlaxGEGLU(nn.Module):
   dropout: float = 0.0
   dtype: jnp.dtype = jnp.float32
   weights_dtype: jnp.dtype = jnp.float32
-  precision: jax.lax.Precision = None
+  precision: jax.lax.Precision | None = None

   def setup(self):
     inner_dim = self.dim * 4
diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py
index 1337f232..dae3b8e9 100644
--- a/src/maxdiffusion/train_utils.py
+++ b/src/maxdiffusion/train_utils.py
@@ -22,8 +22,7 @@


 def get_first_step(state):
-  with jax.spmd_mode("allow_all"):
-    return int(state.step)
+  return int(state.step)


 def load_next_batch(train_iter, example_batch, config):
@@ -101,27 +100,27 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step


 def write_metrics_to_tensorboard(writer, metrics, step, config):
   """Writes metrics to tensorboard"""
-  with jax.spmd_mode("allow_all"):
-    if jax.process_index() == 0:
-      for metric_name in metrics.get("scalar", []):
-        writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
-      for metric_name in metrics.get("scalars", []):
-        writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
-
-    full_log = step % config.log_period == 0
-    if jax.process_index() == 0:
-      max_logging.log(
-          "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format(
-              step,
-              metrics["scalar"]["perf/step_time_seconds"],
-              metrics["scalar"]["perf/per_device_tflops_per_sec"],
-              float(metrics["scalar"]["learning/loss"]),
-          )
-      )
-    if full_log and jax.process_index() == 0:
-      max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
-      writer.flush()
+  if jax.process_index() == 0:
+    for metric_name in metrics.get("scalar", []):
+      writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
+    for metric_name in metrics.get("scalars", []):
+      writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
+
+  full_log = step % config.log_period == 0
+  if jax.process_index() == 0:
+    max_logging.log(
+        "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format(
+            step,
+            metrics["scalar"]["perf/step_time_seconds"],
+            metrics["scalar"]["perf/per_device_tflops_per_sec"],
+            float(metrics["scalar"]["learning/loss"]),
+        )
+    )
+
+  if full_log and jax.process_index() == 0:
+    max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
+    writer.flush()


 def get_params_to_save(params):
diff --git a/src/maxdiffusion/utils/outputs.py b/src/maxdiffusion/utils/outputs.py
index ee4e5f26..6a08ca9e 100644
--- a/src/maxdiffusion/utils/outputs.py
+++ b/src/maxdiffusion/utils/outputs.py
@@ -16,7 +16,7 @@
 """

 from collections import OrderedDict
-from dataclasses import fields, is_dataclass
+from dataclasses import dataclass, fields, is_dataclass
 from typing import Any, Tuple

 import numpy as np
@@ -37,6 +37,7 @@ def is_tensor(x):
   return isinstance(x, np.ndarray)


+@dataclass
 class BaseOutput(OrderedDict):
   """
   Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
diff --git a/src/maxdiffusion/utils/pil_utils.py b/src/maxdiffusion/utils/pil_utils.py
index bf376818..a9b72818 100644
--- a/src/maxdiffusion/utils/pil_utils.py
+++ b/src/maxdiffusion/utils/pil_utils.py
@@ -16,11 +16,11 @@
   }
 else:
   PIL_INTERPOLATION = {
-      "linear": PIL.Image.LINEAR,
-      "bilinear": PIL.Image.BILINEAR,
-      "bicubic": PIL.Image.BICUBIC,
-      "lanczos": PIL.Image.LANCZOS,
-      "nearest": PIL.Image.NEAREST,
+      "linear": PIL.Image.LINEAR,  # pytype: disable=module-attr
+      "bilinear": PIL.Image.BILINEAR,  # pytype: disable=module-attr
+      "bicubic": PIL.Image.BICUBIC,  # pytype: disable=module-attr
+      "lanczos": PIL.Image.LANCZOS,  # pytype: disable=module-attr
+      "nearest": PIL.Image.NEAREST,  # pytype: disable=module-attr
   }


@@ -50,7 +50,7 @@ def numpy_to_pil(images):
   return pil_images


-def make_image_grid(images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None) -> PIL.Image.Image:
+def make_image_grid(images: List[PIL.Image.Image], rows: int, cols: int, resize: int | None = None) -> PIL.Image.Image:
   """
   Prepares a single grid of images. Useful for visualization purposes.
   """
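The attention_flax.py hunks all apply one annotation pattern: a field whose default is `None` is declared as an explicit optional (`X | None = None`) instead of being typed as if a value were always present, which is what pytype flags. A minimal sketch of the pattern, assuming Python 3.10+ union syntax and Flax's linen API; `ToyBlock` and its fields are illustrative, not classes from the diff:

```python
# Sketch only: the optional-annotation pattern used throughout attention_flax.py,
# shown on a hypothetical module (assumes Python 3.10+ and flax.linen).
import jax
import jax.numpy as jnp
from flax import linen as nn


class ToyBlock(nn.Module):  # hypothetical, not a class from the diff
  features: int
  scale: float = 1.0
  # Before the fix this would read `precision: jax.lax.Precision = None`,
  # which pytype rejects because None is not a Precision. Declaring the
  # optional explicitly keeps the None default and satisfies the checker.
  precision: jax.lax.Precision | None = None

  @nn.compact
  def __call__(self, x):
    # precision=None lets XLA pick its default matmul precision.
    return nn.Dense(self.features, precision=self.precision)(x) * self.scale


# Usage sketch: shapes and seed are arbitrary.
out, _ = ToyBlock(features=8).init_with_output(jax.random.PRNGKey(0), jnp.ones((2, 16)))
```

On interpreters older than 3.10 the same fix would be written with `typing.Optional[X]`; the `X | None` form in the diff requires Python 3.10+ (or postponed annotation evaluation).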