4 changes: 2 additions & 2 deletions tsgm/models/architectures/zoo.py
@@ -495,7 +495,7 @@ def _build_generator(self) -> keras.models.Model:
x = layers.Conv1D(1, 8, padding="same")(x)
x = layers.LSTM(256, return_sequences=True)(x)

pool_and_stride = round((x.shape[1] + 1) / (self._seq_len + 1))
pool_and_stride = math.ceil((x.shape[1] + 1) / (self._seq_len + 1))

x = layers.AveragePooling1D(pool_size=pool_and_stride, strides=pool_and_stride)(x)
g_output = layers.Conv1D(self._feat_dim, 1, activation="tanh")(x)
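The rounding change above matters for the generator's output shape: math.ceil can only make the pooling window at least as large as the ratio, so AveragePooling1D can never return more than self._seq_len steps, whereas round can pick a window that is too small and leave the output longer than the target. A minimal sketch of the arithmetic, with hypothetical lengths chosen for illustration:

```python
import math

# Hypothetical values: the LSTM output is 30 steps long, the generator must emit seq_len == 20 steps.
length, seq_len = 30, 20
ratio = (length + 1) / (seq_len + 1)   # ~1.48

p_round = round(ratio)                 # 1
p_ceil = math.ceil(ratio)              # 2

# AveragePooling1D(pool_size=p, strides=p) with "valid" padding keeps floor(length / p) steps.
print(length // p_round)               # 30 -> longer than seq_len, breaks the expected output shape
print(length // p_ceil)                # 15 -> never exceeds seq_len
```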
@@ -604,7 +604,7 @@ def _build_model(self) -> keras.models.Model:
for _ in range(self._n_conv_lstm_blocks):
x = layers.Conv1D(filters=64, kernel_size=3, activation="relu")(x)
x = layers.Dropout(0.2)(x)
x = layers.LSTM(128, activation="relu", return_sequences=True)(x)
x = layers.LSTM(128, activation="tanh", return_sequences=True)(x)
x = layers.Dropout(0.2)(x)
x = layers.Flatten()(x)
x = layers.Dense(128, activation="relu")(x)
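The activation change restores the Keras default for LSTM. A plausible additional motivation, stated here as an assumption rather than something the diff documents: on the TensorFlow backend the fused cuDNN kernel is only used with activation="tanh", and "relu" leaves the recurrent activations unbounded, which tends to destabilize training. A quick check of the default:

```python
from keras import layers

lstm = layers.LSTM(128, return_sequences=True)
print(lstm.activation.__name__)  # "tanh", matching the value now set explicitly above
```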
160 changes: 107 additions & 53 deletions tsgm/models/cgan.py
@@ -344,6 +344,8 @@ def __init__(self, discriminator: keras.Model, generator: keras.Model, latent_di
:type latent_dim: int
:param temporal: Indicates whether the time series is temporally labeled or not.
:type temporal: bool
:param use_wgan: Whether to use the Wasserstein GAN loss with gradient penalty (WGAN-GP).
:type use_wgan: bool
"""
super(ConditionalGAN, self).__init__()
self.discriminator = discriminator
@@ -355,6 +357,9 @@ def __init__(self, discriminator: keras.Model, generator: keras.Model, latent_di
self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")
self._temporal = temporal

self.use_wgan = use_wgan
self.gp_weight = 10.0

def call(self, inputs):
"""
Forward pass for the ConditionalGAN model.
@@ -398,6 +403,51 @@ def compile(self, d_optimizer: keras.optimizers.Optimizer, g_optimizer: keras.op

self.dp = generator_dp and discriminator_dp

def wgan_discriminator_loss(self, real_sample, fake_sample):
real_loss = ops.mean(real_sample)
fake_loss = ops.mean(fake_sample)
return fake_loss - real_loss

# Loss function used for the generator
def wgan_generator_loss(self, fake_sample):
return -ops.mean(fake_sample)

def gradient_penalty_tf(self, tf, interpolated):
with tf.GradientTape() as gp_tape:
gp_tape.watch(interpolated)
# 1. Get the discriminator output for this interpolated sample.
pred = self.discriminator(interpolated, training=True)

# 2. Calculate the gradients w.r.t to this interpolated sample.
grads = gp_tape.gradient(pred, [interpolated])[0]
return grads

def gradient_penalty_torch(self, torch, interpolated):
# Create a new tensor that requires grad instead of modifying the existing one
interpolated = interpolated.detach().requires_grad_(True)
pred = self.discriminator(interpolated, training=True)
grads = torch.autograd.grad(outputs=pred, inputs=interpolated,
grad_outputs=ops.ones_like(pred),
create_graph=True, retain_graph=True, only_inputs=True)[0]
return grads

def gradient_penalty(self, batch_size, real_samples, fake_samples):
# get the interpolated samples
alpha_shape = [batch_size, 1, 1]
# Create alpha on the same device as real_samples
alpha = ops.ones_like(real_samples[:, :1, :1]) * keras.random.normal(alpha_shape, 0.0, 1.0)
diff = fake_samples - real_samples
interpolated = real_samples + alpha * diff
backend = get_backend()
if os.environ.get("KERAS_BACKEND") == "tensorflow":
grads = self.gradient_penalty_tf(backend, interpolated)
elif os.environ.get("KERAS_BACKEND") == "torch":
grads = self.gradient_penalty_torch(backend, interpolated)
# 3. Calculate the norm of the gradients
norm = ops.sqrt(ops.sum(ops.square(grads), axis=[1, 2]))
gp = ops.mean((norm - 1.0) ** 2)
return gp

def _get_random_vector_labels(self, batch_size: int, labels: tsgm.types.Tensor) -> None:
if self._temporal:
random_latent_vectors = keras.random.normal(shape=(batch_size, self._seq_len, self.latent_dim))
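Taken together, wgan_discriminator_loss, wgan_generator_loss and gradient_penalty implement the WGAN-GP objective of Gulrajani et al. (2017), with the penalty weight fixed at gp_weight = 10 and the penalty evaluated at points interpolated between real and generated (time series + label) samples; note that the interpolation coefficient above is drawn from a normal distribution, whereas the original formulation samples it uniformly from [0, 1]:

$$
\mathcal{L}_D = \mathbb{E}_{\tilde{x}\sim p_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x\sim p_r}\big[D(x)\big]
+ \lambda\,\mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big],
\qquad
\mathcal{L}_G = -\,\mathbb{E}_{\tilde{x}\sim p_g}\big[D(\tilde{x})\big]
$$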
@@ -425,63 +475,63 @@ def train_step_tf(self, tf, data: T.Tuple) -> T.Dict[str, float]:
labels = data[1]
output_dim = self._get_output_shape(labels)
batch_size = ops.shape(real_ts)[0]

# Prepare labels
if not self._temporal:
rep_labels = labels[:, :, None]
rep_labels = ops.repeat(
rep_labels, repeats=[self._seq_len]
)
rep_labels = ops.repeat(rep_labels, repeats=[self._seq_len])
else:
rep_labels = labels

rep_labels = ops.reshape(
rep_labels, (-1, self._seq_len, output_dim)
)
rep_labels = ops.reshape(rep_labels, (-1, self._seq_len, output_dim))

# Generate ts
# Generate Fake Data
random_vector_labels = self._get_random_vector_labels(batch_size=batch_size, labels=labels)
generated_ts = self.generator(random_vector_labels)

# Concatenate TS with Labels for the Discriminator
fake_data = ops.concatenate([generated_ts, rep_labels], -1)
real_data = ops.concatenate([real_ts, rep_labels], -1)
combined_data = ops.concatenate(
[fake_data, real_data], axis=0
)

# Labels for the discriminator
# 1 == fake data
# 0 == real data
desc_labels = ops.concatenate(
[ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
)

# Combined data (used for standard GAN only)
combined_data = ops.concatenate([fake_data, real_data], axis=0)

# 1. Train Discriminator
with tf.GradientTape() as tape:
predictions = self.discriminator(combined_data)
d_loss = self.loss_fn(desc_labels, predictions)
if self.use_wgan:
fake_logits = self.discriminator(fake_data, training=True)
real_logits = self.discriminator(real_data, training=True)
d_cost = self.wgan_discriminator_loss(real_logits, fake_logits)

# GP is calculated on the CONCATENATED data (ts + labels)
gp = self.gradient_penalty(batch_size, real_data, fake_data)
d_loss = d_cost + gp * self.gp_weight
else:
desc_labels = ops.concatenate([ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0)
predictions = self.discriminator(combined_data)
d_loss = self.loss_fn(desc_labels, predictions)

if self.dp:
# For DP optimizers from `tensorflow.privacy`
self.d_optimizer.minimize(d_loss, self.discriminator.trainable_weights, tape=tape)
else:
grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

self.d_optimizer.apply_gradients(
zip(grads, self.discriminator.trainable_weights)
)

# 2. Train Generator
random_vector_labels = self._get_random_vector_labels(batch_size=batch_size, labels=labels)

# Pretend that all samples are real
misleading_labels = ops.zeros((batch_size, 1))

# Train the generator (without updating the discriminator)
with tf.GradientTape() as tape:
fake_samples = self.generator(random_vector_labels)
fake_data = ops.concatenate([fake_samples, rep_labels], -1)
predictions = self.discriminator(fake_data)
g_loss = self.loss_fn(misleading_labels, predictions)

if self.use_wgan:
g_loss = self.wgan_generator_loss(predictions)
else:
g_loss = self.loss_fn(misleading_labels, predictions)

if self.dp:
# For DP optimizers from `tensorflow.privacy`
self.g_optimizer.minimize(g_loss, self.generator.trainable_weights, tape=tape)
else:
grads = tape.gradient(g_loss, self.generator.trainable_weights)
@@ -501,58 +551,63 @@ def train_step_torch(self, torch, data: T.Tuple) -> T.Dict[str, float]:
else:
# Fallback for single input
real_ts, labels = data, None

output_dim = self._get_output_shape(labels)
batch_size = ops.shape(real_ts)[0]

# Prepare labels
if not self._temporal:
rep_labels = labels[:, :, None]
rep_labels = ops.repeat(
rep_labels, repeats=[self._seq_len]
)
rep_labels = ops.repeat(rep_labels, repeats=[self._seq_len])
else:
rep_labels = labels

rep_labels = ops.reshape(
rep_labels, (-1, self._seq_len, output_dim)
)
rep_labels = ops.reshape(rep_labels, (-1, self._seq_len, output_dim))

# Generate ts
# Generate Fake Data
random_vector_labels = self._get_random_vector_labels(batch_size=batch_size, labels=labels)
generated_ts = self.generator(random_vector_labels)

# Concatenate TS with Labels
fake_data = ops.concatenate([generated_ts, rep_labels], -1)
real_data = ops.concatenate([real_ts, rep_labels], -1)
combined_data = ops.concatenate(
[fake_data, real_data], axis=0
)
combined_data = ops.concatenate([fake_data, real_data], axis=0)

# Labels for the discriminator
# 1 == fake data
# 0 == real data
desc_labels = ops.concatenate(
[ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
)
predictions = self.discriminator(combined_data)
d_loss = self.loss_fn(desc_labels, predictions)
# 1. Train Discriminator
if self.use_wgan:
fake_logits = self.discriminator(fake_data, training=True)
real_logits = self.discriminator(real_data, training=True)
d_cost = self.wgan_discriminator_loss(real_logits, fake_logits)
gp = self.gradient_penalty(batch_size, real_data, fake_data)
d_loss = d_cost + gp * self.gp_weight
else:
desc_labels = ops.concatenate([ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0)
predictions = self.discriminator(combined_data)
d_loss = self.loss_fn(desc_labels, predictions)

self.discriminator.zero_grad()
d_loss.backward()

d_trainable_weights = [v for v in self.discriminator.trainable_weights]
d_gradients = [v.value.grad for v in d_trainable_weights]

with torch.no_grad():
# Keras 3 expects (gradient, variable) pairs
grads_and_vars = list(zip(d_gradients, d_trainable_weights))
self.d_optimizer.apply_gradients(grads_and_vars)
random_vector_labels = self._get_random_vector_labels(batch_size=batch_size, labels=labels)

# Pretend that all samples are real
# 2. Train Generator
random_vector_labels = self._get_random_vector_labels(batch_size=batch_size, labels=labels)
misleading_labels = ops.zeros((batch_size, 1))

# Train the generator (without updating the discriminator)
# Re-generate to keep computational graph valid for generator
fake_samples = self.generator(random_vector_labels)
fake_data = ops.concatenate([fake_samples, rep_labels], -1)
predictions = self.discriminator(fake_data)
g_loss = self.loss_fn(misleading_labels, predictions)

if self.use_wgan:
g_loss = self.wgan_generator_loss(predictions)
else:
g_loss = self.loss_fn(misleading_labels, predictions)

self.generator.zero_grad()
g_loss.backward()
@@ -561,7 +616,6 @@ def train_step_torch(self, torch, data: T.Tuple) -> T.Dict[str, float]:
g_gradients = [v.value.grad for v in g_trainable_weights]

with torch.no_grad():
# Keras 3 expects (gradient, variable) pairs
grads_and_vars = list(zip(g_gradients, g_trainable_weights))
self.g_optimizer.apply_gradients(grads_and_vars)

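A minimal usage sketch of the new flag. Assumptions not shown in the diff: the constructor accepts use_wgan as a keyword (its full signature is truncated above), compile keeps its d_optimizer / g_optimizer / loss_fn interface, and discriminator, generator and dataset are built elsewhere with shapes matching tsgm's conventions:

```python
import keras
from tsgm.models.cgan import ConditionalGAN

# `discriminator` and `generator` are assumed to be pre-built keras.Model instances,
# e.g. from tsgm.models.architectures.zoo; `dataset` yields (time_series, labels) batches.
cond_gan = ConditionalGAN(
    discriminator=discriminator,
    generator=generator,
    latent_dim=64,
    use_wgan=True,   # route both train_step_* implementations through the WGAN-GP losses
)
cond_gan.compile(
    d_optimizer=keras.optimizers.Adam(1e-4),
    g_optimizer=keras.optimizers.Adam(1e-4),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),  # only used when use_wgan=False
)
cond_gan.fit(dataset, epochs=100)
```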
6 changes: 3 additions & 3 deletions tsgm/models/cvae.py
@@ -75,7 +75,7 @@ def train_step_tf(self, tf, data: tsgm.types.Tensor) -> T.Dict:
reconstruction_loss = self._get_reconstruction_loss(data, reconstruction)
kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
total_loss = reconstruction_loss + kl_loss
total_loss = reconstruction_loss + self.beta * kl_loss
grads = tape.gradient(total_loss, self.trainable_weights)
# I am not sure if this should be self.optimizer.apply(grads, model.trainable_weights)
# see https://keras.io/guides/writing_a_custom_training_loop_in_tensorflow/
@@ -95,7 +95,7 @@ def train_step_torch(self, torch, data: tsgm.types.Tensor) -> T.Dict:
reconstruction_loss = self._get_reconstruction_loss(data, reconstruction)
kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
total_loss = reconstruction_loss + kl_loss
total_loss = reconstruction_loss + self.beta * kl_loss
# Ensure total_loss is a scalar for PyTorch backward()
if hasattr(total_loss, 'shape') and len(total_loss.shape) > 0:
total_loss = ops.mean(total_loss)
@@ -124,7 +124,7 @@ def train_step_jax(self, jax, data: tsgm.types.Tensor) -> T.Dict:
reconstruction_loss = self._get_reconstruction_loss(data, reconstruction)
kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
total_loss = reconstruction_loss + kl_loss
total_loss = reconstruction_loss + self.beta * kl_loss

self.total_loss_tracker.update_state(total_loss)
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
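For reference, the objective now minimized identically by the TensorFlow, PyTorch and JAX train steps is the beta-weighted ELBO (beta = 1 recovers the plain VAE loss), with the KL term in its closed form for a diagonal Gaussian posterior, exactly as computed by kl_loss above:

$$
\mathcal{L} = \mathcal{L}_{\mathrm{rec}} + \beta \cdot \mathrm{KL}\big(q_\phi(z \mid x)\,\|\,\mathcal{N}(0, I)\big),
\qquad
\mathrm{KL} = -\tfrac{1}{2}\sum_{j}\big(1 + \log\sigma_j^{2} - \mu_j^{2} - \sigma_j^{2}\big)
$$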