From 5edf5573e7b8b78a65a16e432655018d49fd6e34 Mon Sep 17 00:00:00 2001 From: John Mai Date: Sat, 21 Jun 2025 13:18:38 +0800 Subject: [PATCH 1/7] feat: Add support for Falcon H1 model --- Libraries/MLXLLM/LLMModelFactory.swift | 1 + Libraries/MLXLLM/Models/FalconH1.swift | 1057 ++++++++++++++++++++++++ 2 files changed, 1058 insertions(+) create mode 100644 Libraries/MLXLLM/Models/FalconH1.swift diff --git a/Libraries/MLXLLM/LLMModelFactory.swift b/Libraries/MLXLLM/LLMModelFactory.swift index ddd8d1b6..95913ed4 100644 --- a/Libraries/MLXLLM/LLMModelFactory.swift +++ b/Libraries/MLXLLM/LLMModelFactory.swift @@ -46,6 +46,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable { "mimo": create(MiMoConfiguration.self, MiMoModel.init), "glm4": create(GLM4Configuration.self, GLM4Model.init), "acereason": create(Qwen2Configuration.self, Qwen2Model.init), + "falcon_h1": create(FalconH1Configuration.self, FalconH1Model.init), ] } diff --git a/Libraries/MLXLLM/Models/FalconH1.swift b/Libraries/MLXLLM/Models/FalconH1.swift new file mode 100644 index 00000000..b947037b --- /dev/null +++ b/Libraries/MLXLLM/Models/FalconH1.swift @@ -0,0 +1,1057 @@ +// +// FalconH1.swift +// mlx-swift-examples +// +// Created by John Mai on 2025/6/18. +// + +import Foundation +import MLX +import MLXLMCommon +import MLXNN + +// port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/falcon_h1.py + +// MARK: - RMSNormGated + +private class RMSNormGated: Module { + let weight: MLXArray + let varianceEpsilon: Float + let nGroups: Int + let normBeforeGate: Bool + + init(hiddenSize: Int, eps: Float = 1e-6, nGroups: Int = 1, normBeforeGate: Bool = true) { + self.weight = MLXArray.ones([hiddenSize]) + self.varianceEpsilon = eps + self.nGroups = nGroups + self.normBeforeGate = normBeforeGate + } + + func callAsFunction(_ hiddenStates: MLXArray, gate: MLXArray? = nil) -> MLXArray { + let inputDtype = hiddenStates.dtype + + var hiddenStates = hiddenStates + + if !normBeforeGate, let gate { + hiddenStates = hiddenStates * silu(gate.asType(.float16)) + } + + hiddenStates = MLXFast.rmsNorm(hiddenStates, weight: weight, eps: varianceEpsilon) + + if normBeforeGate, let gate { + hiddenStates = hiddenStates * silu(gate.asType(.float16)) + } + + return hiddenStates.asType(inputDtype) + } +} + +private func computeMupVector(_ args: FalconH1Configuration) -> MLXArray { + let intermediateSize = args.mambaDSSM ?? Int(Float(args.mambaExpand) * Float(args.hiddenSize)) + let groupsTimeStateSize = args.mambaNGroups * args.mambaDState + let numHeads = args.mambaNHeads + let zxbcdtMultipliers = args.ssmMultipliers + + let vectorShape = 2 * intermediateSize + 2 * groupsTimeStateSize + numHeads + let mupVector = MLXArray.ones([1, 1, vectorShape]) + + mupVector[0..., 0..., .. 
MLXArray + { + let (B, L, _) = (x.dim(0), x.dim(1), x.dim(2)) + + var queries = qProj(x) + var keys = kProj(x) + var values = vProj(x) + + keys = keys * keyMultiplier + + queries = queries.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, L, numKVHeads, -1).transposed(0, 2, 1, 3) + values = values.reshaped(B, L, numKVHeads, -1).transposed(0, 2, 1, 3) + + if let cache { + queries = rope(queries, offset: cache.seqlenOffset) + keys = rope(keys, offset: cache.seqlenOffset) + (keys, values) = cache.update(keyStates: keys, valueStates: values, layerIdx: layerIdx) + } else { + queries = rope(queries) + keys = rope(keys) + } + + if var mask { + let kvSeqLen = keys.dim(2) + if mask.ndim == 2 { + mask = mask[.newAxis, .newAxis, 0..., 0...] + } + + if kvSeqLen > L { + if mask.dim(-1) < kvSeqLen { + let numHeadsDim = mask.dim(1) > 1 ? mask.dim(1) : 1 + let padLength = kvSeqLen - mask.dim(-1) + let padShape = [B, numHeadsDim, L, padLength] + let padding = MLXArray.ones(padShape, dtype: mask.dtype) + mask = concatenated([padding, mask], axis: -1) + } + } + } + + var output = MLXFast.scaledDotProductAttention( + queries: queries, + keys: keys, + values: values, + scale: scale, + mask: mask + ) + + output = output.transposed(0, 2, 1, 3).reshaped(B, L, -1) + return oProj(output) + } +} + +// MARK: - MLP + +private class MLP: Module, UnaryLayer { + @ModuleInfo(key: "gate_proj") var gateProj: Linear + @ModuleInfo(key: "up_proj") var upProj: Linear + @ModuleInfo(key: "down_proj") var downProj: Linear + + let gateMultiplier: Float + let downMultiplier: Float + + init(_ args: FalconH1Configuration) { + let hiddenSize = args.hiddenSize + let intermediateSize = args.intermediateSize + + _gateProj.wrappedValue = Linear(hiddenSize, intermediateSize, bias: args.mlpBias) + _upProj.wrappedValue = Linear(hiddenSize, intermediateSize, bias: args.mlpBias) + _downProj.wrappedValue = Linear(intermediateSize, hiddenSize, bias: args.mlpBias) + + self.gateMultiplier = args.mlpMultipliers[0] + self.downMultiplier = args.mlpMultipliers[1] + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let y = upProj(x) * silu(gateProj(x) * gateMultiplier) + return downProj(y) * downMultiplier + } +} + +// MARK: - DecoderLayer + +private class DecoderLayer: Module { + let mamba: Mixer + let channelsAttn: Int + let ssmOutMultiplier: Float + let attnOutMultiplier: Float + let attentionInMultiplier: Float + + @ModuleInfo(key: "feed_forward") var feedForward: MLP + @ModuleInfo(key: "self_attn") var attention: Attention + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm + @ModuleInfo(key: "pre_ff_layernorm") var preFfLayerNorm: RMSNorm + + init(_ args: FalconH1Configuration, layerIdx: Int, mupVector: MLXArray) { + self.mamba = Mixer(args, layerIdx: layerIdx, mupVector: mupVector) + + let headDim = args.hiddenSize / args.numAttentionHeads + self.channelsAttn = args.numAttentionHeads * headDim + 2 * args.numKeyValueHeads * headDim + + self.attentionInMultiplier = args.attentionInMultiplier + self.ssmOutMultiplier = args.ssmOutMultiplier + self.attnOutMultiplier = args.attentionOutMultiplier + + _feedForward.wrappedValue = MLP(args) + _attention.wrappedValue = Attention(args, layerIdx: layerIdx) + _inputLayerNorm.wrappedValue = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) + _preFfLayerNorm.wrappedValue = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + func callAsFunction( + _ hiddenStates: MLXArray, + cache: Mamba2Cache, + mask: MLXArray?, + mambaMask: MLXArray?, + 
cachePosition: MLXArray + ) -> MLXArray { + var residual = hiddenStates + var hiddenStates = inputLayerNorm(hiddenStates) + + let mambaHiddenStates = + mamba( + hiddenStates, + cache: cache, + mask: mambaMask, + cachePosition: cachePosition + ) * ssmOutMultiplier + + let attentionHiddenStates = + attention( + hiddenStates * attentionInMultiplier, + mask: mask, + cache: cache + ) * attnOutMultiplier + + hiddenStates = mambaHiddenStates + attentionHiddenStates + + hiddenStates = residual + hiddenStates + + residual = hiddenStates + hiddenStates = preFfLayerNorm(hiddenStates) + hiddenStates = feedForward(hiddenStates) + hiddenStates = residual + hiddenStates + + return hiddenStates + } +} + +private func applyMaskToPaddingStates(_ inputStates: MLXArray, _ attentionMask: MLXArray?) + -> MLXArray +{ + if let attentionMask { + let mask = expandedDimensions(attentionMask, axes: [-1]) + return inputStates * mask + } + return inputStates +} + +private func padTensorBySize(_ tensor: MLXArray, _ padSize: Int) -> MLXArray { + if padSize > 0 { + var padShape = tensor.shape + padShape[1] = padSize + let padding = MLXArray.zeros(padShape).asType(tensor.dtype) + return concatenated([tensor, padding], axis: 1) + } + return tensor +} + +private func reshapeIntoChunks(_ tensor: MLXArray, _ padSize: Int, _ chunkSize: Int) -> MLXArray { + var tensor = tensor + if padSize > 0 { + tensor = padTensorBySize(tensor, padSize) + } + + let batchSize = tensor.shape[0] + let seqLen = tensor.dim(1) + let numChunks = seqLen / chunkSize + + var newShape = [batchSize, numChunks, chunkSize] + newShape.append(contentsOf: Array(tensor.shape[2...])) + return tensor.reshaped(newShape) +} + +private func segmentSum(_ inputTensor: MLXArray) -> MLXArray { + let chunkSize = inputTensor.dim(-1) + var inputTensor = expandedDimensions(inputTensor, axes: [-1]) + inputTensor = broadcast( + inputTensor, to: inputTensor.shape[0 ..< inputTensor.ndim - 1] + [chunkSize]) + + var mask = tri(chunkSize, k: -1).asType(.bool) + inputTensor = MLX.where(mask, inputTensor, MLXArray.zeros(like: inputTensor)) + + let tensorSegsum = cumsum(inputTensor, axis: -2) + + mask = tri(chunkSize, k: 0).asType(.bool) + return MLX.where(mask, tensorSegsum, MLXArray(-Float.infinity)) +} + +// MARK: - Mixer + +private class Mixer: Module { + let numHeads: Int + let hiddenSize: Int + let ssmStateSize: Int + let convKernelSize: Int + let intermediateSize: Int + let layerIdx: Int + let useConvBias: Bool + let useBias: Bool + let layerNormEpsilon: Float + let groupsTimeStateSize: Int + let nGroups: Int + let headDim: Int + let chunkSize: Int + let timeStepLimit: (Float, Float) + let timeStepMin: Float + let timeStepMax: Float + let convDim: Int + let mambaRMSNorm: Bool + let norm: RMSNormGated? + let ssmInMultiplier: Float + let conv1d: Conv1d + + let _mupVector: MLXArray + + @ModuleInfo(key: "in_proj") var inProj: Linear + @ParameterInfo(key: "dt_bias") var dtBias: MLXArray + @ParameterInfo(key: "A_log") var aLog: MLXArray + @ParameterInfo(key: "D") var d: MLXArray + @ModuleInfo(key: "out_proj") var outProj: Linear + + init(_ args: FalconH1Configuration, layerIdx: Int, mupVector: MLXArray) { + self.numHeads = args.mambaNHeads + self.hiddenSize = args.hiddenSize + self.ssmStateSize = args.mambaDState + self.convKernelSize = args.mambaDConv + self.intermediateSize = args.mambaDSSM ?? 
Int(args.mambaExpand * args.hiddenSize) + self.layerIdx = layerIdx + self.useConvBias = args.mambaConvBias + self.useBias = args.mambaProjBias + self.layerNormEpsilon = args.rmsNormEps + self.groupsTimeStateSize = args.mambaNGroups * args.mambaDState + self.nGroups = args.mambaNGroups + self.headDim = args.mambaDHead + self.chunkSize = args.mambaChunkSize + self.timeStepLimit = (0.0, Float.infinity) + self.timeStepMin = 0.001 + self.timeStepMax = 0.1 + + self.convDim = intermediateSize + 2 * nGroups * ssmStateSize + + self.conv1d = Conv1d( + inputChannels: convDim, + outputChannels: convDim, + kernelSize: convKernelSize, + padding: convKernelSize - 1, + groups: convDim, + bias: useConvBias + ) + + let projectionSize = intermediateSize + convDim + numHeads + _inProj.wrappedValue = Linear( + hiddenSize, + projectionSize, + bias: args.mambaProjBias + ) + + _dtBias.wrappedValue = MLXArray.ones([numHeads]) + + let A = MLXArray(Array(1 ... numHeads)).asType(.float32) + + _aLog.wrappedValue = log(A) + + self.mambaRMSNorm = args.mambaRMSNorm + if mambaRMSNorm { + self.norm = RMSNormGated( + hiddenSize: intermediateSize, + eps: layerNormEpsilon, + nGroups: nGroups, + normBeforeGate: args.mambaNormBeforeGate + ) + } else { + self.norm = nil + } + + _d.wrappedValue = MLXArray.ones([numHeads]) + 1.0 + + _outProj.wrappedValue = Linear( + intermediateSize, + hiddenSize, + bias: args.projectorsBias + ) + + self.ssmInMultiplier = args.ssmInMultiplier + self._mupVector = mupVector + } + + func callAsFunction( + _ inputStates: MLXArray, cache: Mamba2Cache? = nil, mask: MLXArray? = nil, + cachePosition: MLXArray? = nil + ) -> MLXArray { + let (batchSize, seqLen, _) = (inputStates.dim(0), inputStates.dim(1), inputStates.dim(2)) + let dtype = inputStates.dtype + + let mask: MLXArray? = mask?[..<1, .ellipsis] + + var inputStates = applyMaskToPaddingStates(inputStates, mask) + + inputStates = inputStates * ssmInMultiplier + var projectedStates = inProj(inputStates) + projectedStates = projectedStates * _mupVector + + let gate = projectedStates[.ellipsis, .. 0 + + if usePrecomputedStates, let cache { + var convState = roll(cache.convStates[layerIdx]!, shift: -1, axis: -1) + convState[0..., 0..., -1] = hiddenStatesBC[0..., 0, 0...] + cache.convStates[layerIdx] = convState + + hiddenStatesBC = sum(convState * squeezed(conv1d.weight, axis: -1), axis: -1) + if useConvBias { + hiddenStatesBC = hiddenStatesBC + conv1d.bias! + } + hiddenStatesBC = silu(hiddenStatesBC) + } else { + if let cache { + let hiddenStatesBCTransposed = hiddenStatesBC.transposed(0, 2, 1) + let seqLenTransposed: Int = hiddenStatesBCTransposed.dim(-1) + let padSize = convKernelSize - seqLenTransposed + + let convStates: MLXArray = + if padSize > 0 { + padded( + hiddenStatesBCTransposed, + widths: [.init((0, 0)), .init((0, 0)), .init((padSize, 0))]) + } else { + hiddenStatesBCTransposed[0..., 0..., .. 0 { + y = y[0..., .. 0) + + _embedTokens.wrappedValue = Embedding(embeddingCount: vocabSize, dimensions: hiddenSize) + + let mupVector = computeMupVector(args) + self.layers = (0 ..< args.numHiddenLayers).map { layerIdx in + DecoderLayer(args, layerIdx: layerIdx, mupVector: mupVector) + } + + _finalLayerNorm.wrappedValue = RMSNorm(dimensions: hiddenSize, eps: args.rmsNormEps) + } + + func callAsFunction(_ inputs: MLXArray, mask: MLXArray? = nil, cache: [Mamba2Cache]? = nil) + -> MLXArray + { + var h = embedTokens(inputs) + + h = h * args.embeddingMultiplier + + let mask = mask ?? createAttentionMask(h: h, cache: nil) + let mambaMask: MLXArray? 
= nil + + let cachePosition = MLXArray(0 ..< h.dim(1)).asType(.int32) + + if h.dim(1) == 1, let cache { + let prevSeqlen = cache[0].keyCache[0].dim(-2) + let cachePosition = cachePosition + prevSeqlen + } + + for (layer, c) in zip(layers, cache!) { + h = layer( + h, + cache: c ?? Mamba2Cache(args), + mask: nil, + mambaMask: mambaMask, + cachePosition: cachePosition + ) + } + + return finalLayerNorm(h) + } +} + +public class FalconH1Model: Module, LLMModel, KVCacheDimensionProvider { + public let vocabularySize: Int + public let kvHeads: [Int] + + private let model: ModelInner + let configuration: FalconH1Configuration + + @ModuleInfo(key: "lm_head") var lmHead: Linear? + + public init(_ args: FalconH1Configuration) { + self.configuration = args + self.vocabularySize = args.vocabSize + self.kvHeads = (0 ..< args.numKeyValueHeads).map { _ in args.numHiddenLayers } + self.model = ModelInner(args) + + if !args.tieWordEmbeddings { + _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabSize, bias: false) + } + } + + public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray { + var out = model(inputs, cache: cache as? [Mamba2Cache]) + if let lmHead { + out = lmHead(out) * configuration.lmHeadMultiplier + } else { + out = model.embedTokens.asLinear(out) + } + + return out + } + + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + var weights = weights + for (name, param) in weights { + if name.contains("conv1d.weight"), param.dim(-1) > param.dim(1) { + weights[name] = param.transposed(0, 2, 1) + } + } + + return weights + } + + public func newCache(parameters: GenerateParameters?) -> [any KVCache] { + model.layers.map { _ in Mamba2Cache(configuration) } + } +} + +// MARK: - LoRA + +extension FalconH1Model: LoRAModel { + public func loraLinearLayers() -> LoRALinearLayers { + model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } + } +} + +// MARK: - Configuration + +public struct FalconH1Configuration: Codable, Sendable { + var attentionBias: Bool + var attentionDropout: Float + var attentionInMultiplier: Float + var attentionOutMultiplier: Float + var bosTokenId: Int + var embeddingMultiplier: Float + var eosTokenId: Int + var headDim: Int? + var hiddenAct: String + var hiddenSize: Int + var initializerRange: Float + var intermediateSize: Int + var keyMultiplier: Float + var lmHeadMultiplier: Float + var mambaChunkSize: Int + var mambaConvBias: Bool + var mambaDConv: Int + var mambaDHead: Int + var mambaDSSM: Int? + var mambaDState: Int + var mambaExpand: Int + var mambaNGroups: Int + var mambaNHeads: Int + var mambaNormBeforeGate: Bool + var mambaProjBias: Bool + var mambaRMSNorm: Bool + var mambaUseMLP: Bool + var maxPositionEmbeddings: Int + var mlpBias: Bool + var mlpExpansionFactor: Int + var mlpMultipliers: [Float] + var modelType: String + var numAttentionHeads: Int + var numHiddenLayers: Int + var numKeyValueHeads: Int + var numLogitsToKeep: Int + var padTokenId: Int + var projectorsBias: Bool + var rmsNormEps: Float + var ropeTraditional: Bool + var ropeScaling: Float? 
+ var ropeTheta: Float + var ssmInMultiplier: Float + var ssmMultipliers: [Float] + var ssmOutMultiplier: Float + var tieWordEmbeddings: Bool + var torchDtype: String + var vocabSize: Int + + enum CodingKeys: String, CodingKey { + case attentionBias = "attention_bias" + case attentionDropout = "attention_dropout" + case attentionInMultiplier = "attention_in_multiplier" + case attentionOutMultiplier = "attention_out_multiplier" + case bosTokenId = "bos_token_id" + case embeddingMultiplier = "embedding_multiplier" + case eosTokenId = "eos_token_id" + case headDim = "head_dim" + case hiddenAct = "hidden_act" + case hiddenSize = "hidden_size" + case initializerRange = "initializer_range" + case intermediateSize = "intermediate_size" + case keyMultiplier = "key_multiplier" + case lmHeadMultiplier = "lm_head_multiplier" + case mambaChunkSize = "mamba_chunk_size" + case mambaConvBias = "mamba_conv_bias" + case mambaDConv = "mamba_d_conv" + case mambaDHead = "mamba_d_head" + case mambaDSSM = "mamba_d_ssm" + case mambaDState = "mamba_d_state" + case mambaExpand = "mamba_expand" + case mambaNGroups = "mamba_n_groups" + case mambaNHeads = "mamba_n_heads" + case mambaNormBeforeGate = "mamba_norm_before_gate" + case mambaProjBias = "mamba_proj_bias" + case mambaRMSNorm = "mamba_rms_norm" + case mambaUseMLP = "mamba_use_mlp" + case maxPositionEmbeddings = "max_position_embeddings" + case mlpBias = "mlp_bias" + case mlpExpansionFactor = "mlp_expansion_factor" + case mlpMultipliers = "mlp_multipliers" + case modelType = "model_type" + case numAttentionHeads = "num_attention_heads" + case numHiddenLayers = "num_hidden_layers" + case numKeyValueHeads = "num_key_value_heads" + case numLogitsToKeep = "num_logits_to_keep" + case padTokenId = "pad_token_id" + case projectorsBias = "projectors_bias" + case rmsNormEps = "rms_norm_eps" + case ropeTraditional = "rope_traditional" + case ropeScaling = "rope_scaling" + case ropeTheta = "rope_theta" + case ssmInMultiplier = "ssm_in_multiplier" + case ssmMultipliers = "ssm_multipliers" + case ssmOutMultiplier = "ssm_out_multiplier" + case tieWordEmbeddings = "tie_word_embeddings" + case torchDtype = "torch_dtype" + case vocabSize = "vocab_size" + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.attentionBias = + try container.decodeIfPresent(Bool.self, forKey: .attentionBias) ?? false + self.attentionDropout = + try container.decodeIfPresent(Float.self, forKey: .attentionDropout) ?? 0.0 + self.attentionInMultiplier = + try container.decodeIfPresent(Float.self, forKey: .attentionInMultiplier) ?? 1.0 + self.attentionOutMultiplier = + try container.decodeIfPresent(Float.self, forKey: .attentionOutMultiplier) ?? 1.0 + self.bosTokenId = try container.decodeIfPresent(Int.self, forKey: .bosTokenId) ?? 1 + self.embeddingMultiplier = + try container.decodeIfPresent(Float.self, forKey: .embeddingMultiplier) ?? 1.0 + self.eosTokenId = try container.decodeIfPresent(Int.self, forKey: .eosTokenId) ?? 2 + self.headDim = try container.decodeIfPresent(Int.self, forKey: .headDim) ?? nil + self.hiddenAct = try container.decodeIfPresent(String.self, forKey: .hiddenAct) ?? "silu" + self.hiddenSize = try container.decodeIfPresent(Int.self, forKey: .hiddenSize) ?? 4096 + self.initializerRange = + try container.decodeIfPresent(Float.self, forKey: .initializerRange) ?? 0.02 + self.intermediateSize = + try container.decodeIfPresent(Int.self, forKey: .intermediateSize) ?? 
14336 + self.keyMultiplier = + try container.decodeIfPresent(Float.self, forKey: .keyMultiplier) ?? 1.0 + self.lmHeadMultiplier = + try container.decodeIfPresent(Float.self, forKey: .lmHeadMultiplier) ?? 1.0 + self.mambaChunkSize = + try container.decodeIfPresent(Int.self, forKey: .mambaChunkSize) ?? 256 + self.mambaConvBias = + try container.decodeIfPresent(Bool.self, forKey: .mambaConvBias) ?? true + self.mambaDConv = try container.decodeIfPresent(Int.self, forKey: .mambaDConv) ?? 4 + self.mambaDHead = try container.decodeIfPresent(Int.self, forKey: .mambaDHead) ?? 64 + self.mambaDSSM = try container.decodeIfPresent(Int.self, forKey: .mambaDSSM) ?? nil + self.mambaDState = try container.decodeIfPresent(Int.self, forKey: .mambaDState) ?? 256 + self.mambaExpand = try container.decodeIfPresent(Int.self, forKey: .mambaExpand) ?? 2 + self.mambaNGroups = try container.decodeIfPresent(Int.self, forKey: .mambaNGroups) ?? 1 + self.mambaNHeads = try container.decodeIfPresent(Int.self, forKey: .mambaNHeads) ?? 128 + self.mambaNormBeforeGate = + try container.decodeIfPresent(Bool.self, forKey: .mambaNormBeforeGate) ?? true + self.mambaProjBias = + try container.decodeIfPresent(Bool.self, forKey: .mambaProjBias) ?? false + self.mambaRMSNorm = try container.decodeIfPresent(Bool.self, forKey: .mambaRMSNorm) ?? false + self.mambaUseMLP = try container.decodeIfPresent(Bool.self, forKey: .mambaUseMLP) ?? true + self.maxPositionEmbeddings = + try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings) ?? 8192 + self.mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) ?? false + self.mlpExpansionFactor = + try container.decodeIfPresent(Int.self, forKey: .mlpExpansionFactor) ?? 8 + self.mlpMultipliers = + try container.decodeIfPresent([Float].self, forKey: .mlpMultipliers) ?? [1.0, 1.0] + self.modelType = + try container.decodeIfPresent(String.self, forKey: .modelType) ?? "falcon_h1" + self.numAttentionHeads = + try container.decodeIfPresent(Int.self, forKey: .numAttentionHeads) ?? 32 + self.numHiddenLayers = + try container.decodeIfPresent(Int.self, forKey: .numHiddenLayers) ?? 32 + self.numKeyValueHeads = + try container.decodeIfPresent(Int.self, forKey: .numKeyValueHeads) ?? 8 + self.numLogitsToKeep = + try container.decodeIfPresent(Int.self, forKey: .numLogitsToKeep) ?? 1 + self.padTokenId = try container.decodeIfPresent(Int.self, forKey: .padTokenId) ?? 0 + self.projectorsBias = + try container.decodeIfPresent(Bool.self, forKey: .projectorsBias) ?? false + self.rmsNormEps = try container.decodeIfPresent(Float.self, forKey: .rmsNormEps) ?? 1e-05 + self.ropeTraditional = + try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) ?? false + self.ropeScaling = try container.decodeIfPresent(Float?.self, forKey: .ropeScaling) ?? nil + self.ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 100000.0 + self.ssmInMultiplier = + try container.decodeIfPresent(Float.self, forKey: .ssmInMultiplier) ?? 1.0 + self.ssmMultipliers = + try container.decodeIfPresent([Float].self, forKey: .ssmMultipliers) ?? [ + 1.0, 1.0, 1.0, 1.0, 1.0, + ] + self.ssmOutMultiplier = + try container.decodeIfPresent(Float.self, forKey: .ssmOutMultiplier) ?? 1.0 + self.tieWordEmbeddings = + try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false + self.torchDtype = + try container.decodeIfPresent(String.self, forKey: .torchDtype) ?? "bfloat16" + self.vocabSize = try container.decodeIfPresent(Int.self, forKey: .vocabSize) ?? 
128000 + } +} + +// MARK: - Mamba2Cache KVCache + +private class Mamba2Cache: KVCache { + var offset: Int + + var maxSize: Int? + + func innerState() -> [MLXArray] { + [] + } + + var seqlenOffset: Int = 0 + var hasPreviousState: Bool = false + let convKernelSize: Int + + private var _seenTokens: Int = 0 + + let intermediateSize: Int + + var convStates: [Int: MLXArray] + var ssmStates: [Int: MLXArray] + + var transformerLayers: [Int] + var keyCache: [MLXArray] + var valueCache: [MLXArray] + + init(_ args: FalconH1Configuration, batchSize: Int = 1) { + self.convKernelSize = args.mambaDConv + + self.intermediateSize = + args.mambaDSSM ?? Int(Float(args.mambaExpand) * Float(args.hiddenSize)) + + self.convStates = [:] + self.ssmStates = [:] + + for i in 0 ..< args.numHiddenLayers { + convStates[i] = MLXArray.zeros([ + batchSize, + intermediateSize + 2 * args.mambaNGroups * args.mambaDState, + convKernelSize, + ]) + ssmStates[i] = MLXArray.zeros([ + batchSize, + args.mambaNHeads, + args.mambaDHead, + args.mambaDState, + ]) + } + + self.seqlenOffset = 0 + self.hasPreviousState = false + self.transformerLayers = Array(0 ..< args.numHiddenLayers) + self.keyCache = [] + self.valueCache = [] + self.offset = 0 + } + + func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) { + update(keyStates: keys, valueStates: values, layerIdx: 0) + } + + func update(keyStates: MLXArray, valueStates: MLXArray, layerIdx: Int) -> (MLXArray, MLXArray) { + if layerIdx == 0 { + _seenTokens += keyStates.dim(-2) + } + + if keyCache.count <= layerIdx { + for _ in keyCache.count ..< layerIdx { + keyCache.append(MLXArray([])) + valueCache.append(MLXArray([])) + } + keyCache.append(keyStates) + valueCache.append(valueStates) + } else if keyCache[layerIdx].size == 0 { + keyCache[layerIdx] = keyStates + valueCache[layerIdx] = valueStates + } else { + keyCache[layerIdx] = concatenated([keyCache[layerIdx], keyStates], axis: -2) + valueCache[layerIdx] = concatenated([valueCache[layerIdx], valueStates], axis: -2) + } + + return (keyCache[layerIdx], valueCache[layerIdx]) + } + + func updateConvState(layerIdx: Int, newConvState: MLXArray, cachePosition: MLXArray) -> MLXArray + { + var convState = convStates[layerIdx]! + let cachePosition = clip(cachePosition, min: 0, max: convKernelSize - 1) + + convState = roll(convState, shift: -1, axis: -1) + + if cachePosition.count > 1 { + convState[0..., 0..., 0...] = newConvState.transposed(0, 2, 1) + } else { + convState[0..., 0..., -1] = newConvState[0..., 0..., -1] + } + + convStates[layerIdx] = convState + return convStates[layerIdx]! + } + + func reset() { + for i in 0 ..< convStates.count { + convStates[i] = MLXArray.zeros(like: convStates[i]!) + ssmStates[i] = MLXArray.zeros(like: ssmStates[i]!) + } + } +} From ad9d92efb75f2a1f6797bdf7e71ec71950629449 Mon Sep 17 00:00:00 2001 From: John Mai Date: Sun, 22 Jun 2025 13:11:16 +0800 Subject: [PATCH 2/7] fix generate --- Libraries/MLXLLM/Models/FalconH1.swift | 337 +++++++++++++------------ 1 file changed, 180 insertions(+), 157 deletions(-) diff --git a/Libraries/MLXLLM/Models/FalconH1.swift b/Libraries/MLXLLM/Models/FalconH1.swift index b947037b..bc07d694 100644 --- a/Libraries/MLXLLM/Models/FalconH1.swift +++ b/Libraries/MLXLLM/Models/FalconH1.swift @@ -47,7 +47,7 @@ private class RMSNormGated: Module { } private func computeMupVector(_ args: FalconH1Configuration) -> MLXArray { - let intermediateSize = args.mambaDSSM ?? Int(Float(args.mambaExpand) * Float(args.hiddenSize)) + let intermediateSize = args.mambaDSSM ?? 
args.mambaExpand * args.hiddenSize let groupsTimeStateSize = args.mambaNGroups * args.mambaDState let numHeads = args.mambaNHeads let zxbcdtMultipliers = args.ssmMultipliers @@ -55,17 +55,40 @@ private func computeMupVector(_ args: FalconH1Configuration) -> MLXArray { let vectorShape = 2 * intermediateSize + 2 * groupsTimeStateSize + numHeads let mupVector = MLXArray.ones([1, 1, vectorShape]) - mupVector[0..., 0..., .. MLXArray { var residual = hiddenStates @@ -253,9 +280,7 @@ private class DecoderLayer: Module { cache: cache ) * attnOutMultiplier - hiddenStates = mambaHiddenStates + attentionHiddenStates - - hiddenStates = residual + hiddenStates + hiddenStates = residual + mambaHiddenStates + attentionHiddenStates residual = hiddenStates hiddenStates = preFfLayerNorm(hiddenStates) @@ -270,8 +295,7 @@ private func applyMaskToPaddingStates(_ inputStates: MLXArray, _ attentionMask: -> MLXArray { if let attentionMask { - let mask = expandedDimensions(attentionMask, axes: [-1]) - return inputStates * mask + return inputStates * expandedDimensions(attentionMask, axis: -1) } return inputStates } @@ -280,7 +304,7 @@ private func padTensorBySize(_ tensor: MLXArray, _ padSize: Int) -> MLXArray { if padSize > 0 { var padShape = tensor.shape padShape[1] = padSize - let padding = MLXArray.zeros(padShape).asType(tensor.dtype) + let padding = MLXArray.zeros(padShape, dtype: tensor.dtype) return concatenated([tensor, padding], axis: 1) } return tensor @@ -292,28 +316,30 @@ private func reshapeIntoChunks(_ tensor: MLXArray, _ padSize: Int, _ chunkSize: tensor = padTensorBySize(tensor, padSize) } - let batchSize = tensor.shape[0] - let seqLen = tensor.dim(1) + let tensorShape = tensor.shape[..<2] + let batchSize = tensorShape[0] + let seqLen = tensorShape[1] let numChunks = seqLen / chunkSize var newShape = [batchSize, numChunks, chunkSize] - newShape.append(contentsOf: Array(tensor.shape[2...])) + newShape.append(contentsOf: tensor.shape[2...]) return tensor.reshaped(newShape) } private func segmentSum(_ inputTensor: MLXArray) -> MLXArray { let chunkSize = inputTensor.dim(-1) - var inputTensor = expandedDimensions(inputTensor, axes: [-1]) + var inputTensor = expandedDimensions(inputTensor, axis: -1) inputTensor = broadcast( - inputTensor, to: inputTensor.shape[0 ..< inputTensor.ndim - 1] + [chunkSize]) + inputTensor, to: inputTensor.shape.dropLast() + [chunkSize] + ) - var mask = tri(chunkSize, k: -1).asType(.bool) - inputTensor = MLX.where(mask, inputTensor, MLXArray.zeros(like: inputTensor)) + var mask = tri(chunkSize, k: -1, dtype: .bool) + inputTensor = MLX.where(mask, inputTensor, 0) let tensorSegsum = cumsum(inputTensor, axis: -2) - mask = tri(chunkSize, k: 0).asType(.bool) - return MLX.where(mask, tensorSegsum, MLXArray(-Float.infinity)) + mask = tri(chunkSize, k: 0, dtype: .bool) + return MLX.where(mask, tensorSegsum, -Float.infinity) } // MARK: - Mixer @@ -337,7 +363,7 @@ private class Mixer: Module { let timeStepMax: Float let convDim: Int let mambaRMSNorm: Bool - let norm: RMSNormGated? + var norm: RMSNormGated? = nil let ssmInMultiplier: Float let conv1d: Conv1d @@ -354,11 +380,11 @@ private class Mixer: Module { self.hiddenSize = args.hiddenSize self.ssmStateSize = args.mambaDState self.convKernelSize = args.mambaDConv - self.intermediateSize = args.mambaDSSM ?? Int(args.mambaExpand * args.hiddenSize) + self.intermediateSize = args.mambaDSSM ?? 
args.mambaExpand * args.hiddenSize self.layerIdx = layerIdx self.useConvBias = args.mambaConvBias self.useBias = args.mambaProjBias - self.layerNormEpsilon = args.rmsNormEps + self.layerNormEpsilon = args.rmsNormEps ?? 1e-5 self.groupsTimeStateSize = args.mambaNGroups * args.mambaDState self.nGroups = args.mambaNGroups self.headDim = args.mambaDHead @@ -387,7 +413,7 @@ private class Mixer: Module { _dtBias.wrappedValue = MLXArray.ones([numHeads]) - let A = MLXArray(Array(1 ... numHeads)).asType(.float32) + let A = MLXArray(Array(1 ..< numHeads + 1)).asType(.float32) _aLog.wrappedValue = log(A) @@ -399,8 +425,6 @@ private class Mixer: Module { nGroups: nGroups, normBeforeGate: args.mambaNormBeforeGate ) - } else { - self.norm = nil } _d.wrappedValue = MLXArray.ones([numHeads]) + 1.0 @@ -432,17 +456,22 @@ private class Mixer: Module { let gate = projectedStates[.ellipsis, .. 0 + let usePrecomputedStates: Bool = { + guard let cache, let cachePosition else { return false } + + return cache.hasPreviousState + && seqLen == 1 + && cache.convStates[layerIdx]?.shape[0] == batchSize + && cache.ssmStates[layerIdx]?.shape[0] == batchSize + && cachePosition[0].item() > 0 + }() if usePrecomputedStates, let cache { - var convState = roll(cache.convStates[layerIdx]!, shift: -1, axis: -1) + let convState = roll(cache.convStates[layerIdx]!, shift: -1, axis: -1) convState[0..., 0..., -1] = hiddenStatesBC[0..., 0, 0...] cache.convStates[layerIdx] = convState @@ -461,7 +490,8 @@ private class Mixer: Module { if padSize > 0 { padded( hiddenStatesBCTransposed, - widths: [.init((0, 0)), .init((0, 0)), .init((padSize, 0))]) + widths: [.init([0, 0]), .init([0, 0]), .init([padSize, 0])] + ) } else { hiddenStatesBCTransposed[0..., 0..., .. 0 { y = y[0..., .. MLXArray { - var h = embedTokens(inputs) + var h = embedTokens(inputs) * args.embeddingMultiplier + let mask = mask ?? createAttentionMask(h: h, cache: cache) + let cache: [Mamba2Cache?] = cache ?? Array(repeating: nil, count: layers.count) - h = h * args.embeddingMultiplier + var cachePosition = MLXArray(0 ..< h.dim(1)).asType(.int32) - let mask = mask ?? createAttentionMask(h: h, cache: nil) - let mambaMask: MLXArray? = nil - - let cachePosition = MLXArray(0 ..< h.dim(1)).asType(.int32) - - if h.dim(1) == 1, let cache { - let prevSeqlen = cache[0].keyCache[0].dim(-2) - let cachePosition = cachePosition + prevSeqlen + if h.dim(1) == 1, let c = cache[0] { + let prevSeqlen = c.keyCache[0].dim(-2) + cachePosition = cachePosition + prevSeqlen } - for (layer, c) in zip(layers, cache!) { + for (layer, c) in zip(layers, cache) { h = layer( h, - cache: c ?? Mamba2Cache(args), - mask: nil, - mambaMask: mambaMask, + cache: c, + mask: mask, cachePosition: cachePosition ) } @@ -775,7 +795,7 @@ public struct FalconH1Configuration: Codable, Sendable { var hiddenAct: String var hiddenSize: Int var initializerRange: Float - var intermediateSize: Int + var intermediateSize: Int? var keyMultiplier: Float var lmHeadMultiplier: Float var mambaChunkSize: Int @@ -802,7 +822,7 @@ public struct FalconH1Configuration: Codable, Sendable { var numLogitsToKeep: Int var padTokenId: Int var projectorsBias: Bool - var rmsNormEps: Float + var rmsNormEps: Float? var ropeTraditional: Bool var ropeScaling: Float? var ropeTheta: Float @@ -884,7 +904,7 @@ public struct FalconH1Configuration: Codable, Sendable { self.initializerRange = try container.decodeIfPresent(Float.self, forKey: .initializerRange) ?? 
0.02 self.intermediateSize = - try container.decodeIfPresent(Int.self, forKey: .intermediateSize) ?? 14336 + try container.decodeIfPresent(Int.self, forKey: .intermediateSize) ?? nil self.keyMultiplier = try container.decodeIfPresent(Float.self, forKey: .keyMultiplier) ?? 1.0 self.lmHeadMultiplier = @@ -926,7 +946,7 @@ public struct FalconH1Configuration: Codable, Sendable { self.padTokenId = try container.decodeIfPresent(Int.self, forKey: .padTokenId) ?? 0 self.projectorsBias = try container.decodeIfPresent(Bool.self, forKey: .projectorsBias) ?? false - self.rmsNormEps = try container.decodeIfPresent(Float.self, forKey: .rmsNormEps) ?? 1e-05 + self.rmsNormEps = try container.decodeIfPresent(Float.self, forKey: .rmsNormEps) ?? nil self.ropeTraditional = try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) ?? false self.ropeScaling = try container.decodeIfPresent(Float?.self, forKey: .ropeScaling) ?? nil @@ -977,23 +997,26 @@ private class Mamba2Cache: KVCache { self.convKernelSize = args.mambaDConv self.intermediateSize = - args.mambaDSSM ?? Int(Float(args.mambaExpand) * Float(args.hiddenSize)) + args.mambaDSSM ?? args.mambaExpand * args.hiddenSize self.convStates = [:] self.ssmStates = [:] + let convStateShape = [ + batchSize, + intermediateSize + 2 * args.mambaNGroups * args.mambaDState, + convKernelSize, + ] + let ssmStateShape = [ + batchSize, + args.mambaNHeads, + args.mambaDHead, + args.mambaDState, + ] + for i in 0 ..< args.numHiddenLayers { - convStates[i] = MLXArray.zeros([ - batchSize, - intermediateSize + 2 * args.mambaNGroups * args.mambaDState, - convKernelSize, - ]) - ssmStates[i] = MLXArray.zeros([ - batchSize, - args.mambaNHeads, - args.mambaDHead, - args.mambaDState, - ]) + convStates[i] = MLXArray.zeros(convStateShape) + ssmStates[i] = MLXArray.zeros(ssmStateShape) } self.seqlenOffset = 0 @@ -1015,8 +1038,8 @@ private class Mamba2Cache: KVCache { if keyCache.count <= layerIdx { for _ in keyCache.count ..< layerIdx { - keyCache.append(MLXArray([])) - valueCache.append(MLXArray([])) + keyCache.append([]) + valueCache.append([]) } keyCache.append(keyStates) valueCache.append(valueStates) From 7900a9d1ba2817bdc124956bd5771f4663400a5c Mon Sep 17 00:00:00 2001 From: John Mai Date: Wed, 1 Oct 2025 12:05:47 +0800 Subject: [PATCH 3/7] Add 'mode' property to Quantization configuration --- Libraries/MLXLMCommon/BaseConfiguration.swift | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Libraries/MLXLMCommon/BaseConfiguration.swift b/Libraries/MLXLMCommon/BaseConfiguration.swift index e9c0ed18..83083f3d 100644 --- a/Libraries/MLXLMCommon/BaseConfiguration.swift +++ b/Libraries/MLXLMCommon/BaseConfiguration.swift @@ -21,6 +21,7 @@ public struct BaseConfiguration: Codable, Sendable { public var quantMethod: String? = nil public var linearClass: String? = nil public var quantizationMode: String? = nil + public var mode: String? = nil public var asTuple: (Int, Int) { (groupSize, bits) } @@ -30,6 +31,7 @@ public struct BaseConfiguration: Codable, Sendable { case quantMethod = "quant_method" case linearClass = "linear_class" case quantizationMode = "quantization_mode" + case mode = "mode" } } @@ -116,6 +118,7 @@ public struct BaseConfiguration: Codable, Sendable { case Quantization.CodingKeys.quantMethod.rawValue: continue case Quantization.CodingKeys.linearClass.rawValue: continue case Quantization.CodingKeys.quantizationMode.rawValue: continue + case Quantization.CodingKeys.mode.rawValue: continue default: if let f = try? 
container.decode(Bool.self, forKey: key) { From e9ac40bc192516209e40fd2af081b6be36cfaf9b Mon Sep 17 00:00:00 2001 From: John Mai Date: Wed, 1 Oct 2025 12:06:19 +0800 Subject: [PATCH 4/7] update --- Libraries/MLXLLM/Models/FalconH1.swift | 1249 +++++++++--------------- 1 file changed, 456 insertions(+), 793 deletions(-) diff --git a/Libraries/MLXLLM/Models/FalconH1.swift b/Libraries/MLXLLM/Models/FalconH1.swift index bc07d694..a79446ab 100644 --- a/Libraries/MLXLLM/Models/FalconH1.swift +++ b/Libraries/MLXLLM/Models/FalconH1.swift @@ -12,6 +12,194 @@ import MLXNN // port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/falcon_h1.py + +// MARK: - Configuration + +public struct FalconH1Configuration: Codable, Sendable { + var attentionBias: Bool + var attentionDropout: Float + var attentionInMultiplier: Float + var attentionOutMultiplier: Float + var bosTokenId: Int + var embeddingMultiplier: Float + var eosTokenId: Int + var headDim: Int + var hiddenAct: String + var hiddenSize: Int + var initializerRange: Float + var intermediateSize: Int? + var keyMultiplier: Float + var lmHeadMultiplier: Float + var mambaChunkSize: Int + var mambaConvBias: Bool + var mambaDConv: Int + var mambaDHead: Int + var mambaDSSM: Int + var mambaDState: Int + var mambaExpand: Int + var mambaNGroups: Int + var mambaNHeads: Int + var mambaNormBeforeGate: Bool + var mambaProjBias: Bool + var mambaRMSNorm: Bool + var mambaUseMLP: Bool + var maxPositionEmbeddings: Int + var mlpBias: Bool + var mlpExpansionFactor: Int + var mlpMultipliers: [Float] + var modelType: String + var numAttentionHeads: Int + var numHiddenLayers: Int + var numKeyValueHeads: Int + var numLogitsToKeep: Int + var padTokenId: Int + var projectorsBias: Bool + var rmsNormEps: Float + var ropeTraditional: Bool + var ropeScaling: Float? 
+ var ropeTheta: Float + var ssmInMultiplier: Float + var ssmMultipliers: [Float] + var ssmOutMultiplier: Float + var tieWordEmbeddings: Bool + var torchDtype: String + var vocabSize: Int + + enum CodingKeys: String, CodingKey { + case attentionBias = "attention_bias" + case attentionDropout = "attention_dropout" + case attentionInMultiplier = "attention_in_multiplier" + case attentionOutMultiplier = "attention_out_multiplier" + case bosTokenId = "bos_token_id" + case embeddingMultiplier = "embedding_multiplier" + case eosTokenId = "eos_token_id" + case headDim = "head_dim" + case hiddenAct = "hidden_act" + case hiddenSize = "hidden_size" + case initializerRange = "initializer_range" + case intermediateSize = "intermediate_size" + case keyMultiplier = "key_multiplier" + case lmHeadMultiplier = "lm_head_multiplier" + case mambaChunkSize = "mamba_chunk_size" + case mambaConvBias = "mamba_conv_bias" + case mambaDConv = "mamba_d_conv" + case mambaDHead = "mamba_d_head" + case mambaDSSM = "mamba_d_ssm" + case mambaDState = "mamba_d_state" + case mambaExpand = "mamba_expand" + case mambaNGroups = "mamba_n_groups" + case mambaNHeads = "mamba_n_heads" + case mambaNormBeforeGate = "mamba_norm_before_gate" + case mambaProjBias = "mamba_proj_bias" + case mambaRMSNorm = "mamba_rms_norm" + case mambaUseMLP = "mamba_use_mlp" + case maxPositionEmbeddings = "max_position_embeddings" + case mlpBias = "mlp_bias" + case mlpExpansionFactor = "mlp_expansion_factor" + case mlpMultipliers = "mlp_multipliers" + case modelType = "model_type" + case numAttentionHeads = "num_attention_heads" + case numHiddenLayers = "num_hidden_layers" + case numKeyValueHeads = "num_key_value_heads" + case numLogitsToKeep = "num_logits_to_keep" + case padTokenId = "pad_token_id" + case projectorsBias = "projectors_bias" + case rmsNormEps = "rms_norm_eps" + case ropeTraditional = "rope_traditional" + case ropeScaling = "rope_scaling" + case ropeTheta = "rope_theta" + case ssmInMultiplier = "ssm_in_multiplier" + case ssmMultipliers = "ssm_multipliers" + case ssmOutMultiplier = "ssm_out_multiplier" + case tieWordEmbeddings = "tie_word_embeddings" + case torchDtype = "torch_dtype" + case vocabSize = "vocab_size" + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.attentionBias = + try container.decodeIfPresent(Bool.self, forKey: .attentionBias) ?? false + self.attentionDropout = + try container.decodeIfPresent(Float.self, forKey: .attentionDropout) ?? 0.0 + self.attentionInMultiplier = + try container.decodeIfPresent(Float.self, forKey: .attentionInMultiplier) ?? 1.0 + self.attentionOutMultiplier = + try container.decodeIfPresent(Float.self, forKey: .attentionOutMultiplier) ?? 1.0 + self.bosTokenId = try container.decodeIfPresent(Int.self, forKey: .bosTokenId) ?? 1 + self.embeddingMultiplier = + try container.decodeIfPresent(Float.self, forKey: .embeddingMultiplier) ?? 1.0 + self.eosTokenId = try container.decodeIfPresent(Int.self, forKey: .eosTokenId) ?? 2 + self.headDim = try container.decodeIfPresent(Int.self, forKey: .headDim) ?? 64 + self.hiddenAct = try container.decodeIfPresent(String.self, forKey: .hiddenAct) ?? "silu" + self.hiddenSize = try container.decodeIfPresent(Int.self, forKey: .hiddenSize) ?? 4096 + self.initializerRange = + try container.decodeIfPresent(Float.self, forKey: .initializerRange) ?? 0.02 + self.intermediateSize = + try container.decodeIfPresent(Int.self, forKey: .intermediateSize) ?? 
nil + self.keyMultiplier = + try container.decodeIfPresent(Float.self, forKey: .keyMultiplier) ?? 1.0 + self.lmHeadMultiplier = + try container.decodeIfPresent(Float.self, forKey: .lmHeadMultiplier) ?? 1.0 + self.mambaChunkSize = + try container.decodeIfPresent(Int.self, forKey: .mambaChunkSize) ?? 256 + self.mambaConvBias = + try container.decodeIfPresent(Bool.self, forKey: .mambaConvBias) ?? true + self.mambaDConv = try container.decodeIfPresent(Int.self, forKey: .mambaDConv) ?? 4 + self.mambaDHead = try container.decodeIfPresent(Int.self, forKey: .mambaDHead) ?? 64 + self.mambaDSSM = try container.decodeIfPresent(Int.self, forKey: .mambaDSSM) ?? 1536 + self.mambaDState = try container.decodeIfPresent(Int.self, forKey: .mambaDState) ?? 256 + self.mambaExpand = try container.decodeIfPresent(Int.self, forKey: .mambaExpand) ?? 2 + self.mambaNGroups = try container.decodeIfPresent(Int.self, forKey: .mambaNGroups) ?? 1 + self.mambaNHeads = try container.decodeIfPresent(Int.self, forKey: .mambaNHeads) ?? 128 + self.mambaNormBeforeGate = + try container.decodeIfPresent(Bool.self, forKey: .mambaNormBeforeGate) ?? true + self.mambaProjBias = + try container.decodeIfPresent(Bool.self, forKey: .mambaProjBias) ?? false + self.mambaRMSNorm = try container.decodeIfPresent(Bool.self, forKey: .mambaRMSNorm) ?? false + self.mambaUseMLP = try container.decodeIfPresent(Bool.self, forKey: .mambaUseMLP) ?? true + self.maxPositionEmbeddings = + try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings) ?? 8192 + self.mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) ?? false + self.mlpExpansionFactor = + try container.decodeIfPresent(Int.self, forKey: .mlpExpansionFactor) ?? 8 + self.mlpMultipliers = + try container.decodeIfPresent([Float].self, forKey: .mlpMultipliers) ?? [1.0, 1.0] + self.modelType = + try container.decodeIfPresent(String.self, forKey: .modelType) ?? "falcon_h1" + self.numAttentionHeads = + try container.decodeIfPresent(Int.self, forKey: .numAttentionHeads) ?? 32 + self.numHiddenLayers = + try container.decodeIfPresent(Int.self, forKey: .numHiddenLayers) ?? 32 + self.numKeyValueHeads = + try container.decodeIfPresent(Int.self, forKey: .numKeyValueHeads) ?? 8 + self.numLogitsToKeep = + try container.decodeIfPresent(Int.self, forKey: .numLogitsToKeep) ?? 1 + self.padTokenId = try container.decodeIfPresent(Int.self, forKey: .padTokenId) ?? 0 + self.projectorsBias = + try container.decodeIfPresent(Bool.self, forKey: .projectorsBias) ?? false + self.rmsNormEps = try container.decodeIfPresent(Float.self, forKey: .rmsNormEps) ?? 1e-5 + self.ropeTraditional = + try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) ?? false + self.ropeScaling = try container.decodeIfPresent(Float?.self, forKey: .ropeScaling) ?? nil + self.ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 100000.0 + self.ssmInMultiplier = + try container.decodeIfPresent(Float.self, forKey: .ssmInMultiplier) ?? 1.0 + self.ssmMultipliers = + try container.decodeIfPresent([Float].self, forKey: .ssmMultipliers) ?? [ + 1.0, 1.0, 1.0, 1.0, 1.0, + ] + self.ssmOutMultiplier = + try container.decodeIfPresent(Float.self, forKey: .ssmOutMultiplier) ?? 1.0 + self.tieWordEmbeddings = + try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false + self.torchDtype = + try container.decodeIfPresent(String.self, forKey: .torchDtype) ?? "bfloat16" + self.vocabSize = try container.decodeIfPresent(Int.self, forKey: .vocabSize) ?? 
128000 + } +} + + // MARK: - RMSNormGated private class RMSNormGated: Module { @@ -28,69 +216,40 @@ private class RMSNormGated: Module { } func callAsFunction(_ hiddenStates: MLXArray, gate: MLXArray? = nil) -> MLXArray { - let inputDtype = hiddenStates.dtype - var hiddenStates = hiddenStates if !normBeforeGate, let gate { - hiddenStates = hiddenStates * silu(gate.asType(.float16)) + hiddenStates = hiddenStates * silu(gate) } hiddenStates = MLXFast.rmsNorm(hiddenStates, weight: weight, eps: varianceEpsilon) if normBeforeGate, let gate { - hiddenStates = hiddenStates * silu(gate.asType(.float16)) + hiddenStates = hiddenStates * silu(gate) } - return hiddenStates.asType(inputDtype) + return hiddenStates } } private func computeMupVector(_ args: FalconH1Configuration) -> MLXArray { - let intermediateSize = args.mambaDSSM ?? args.mambaExpand * args.hiddenSize + let intermediateSize = args.mambaDSSM let groupsTimeStateSize = args.mambaNGroups * args.mambaDState let numHeads = args.mambaNHeads - let zxbcdtMultipliers = args.ssmMultipliers - - let vectorShape = 2 * intermediateSize + 2 * groupsTimeStateSize + numHeads - let mupVector = MLXArray.ones([1, 1, vectorShape]) - - mupVector[ - 0..., - 0..., - .. MLXArray + func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? = nil) -> MLXArray { let (B, L, _) = (x.dim(0), x.dim(1), x.dim(2)) @@ -148,38 +303,19 @@ private class Attention: Module { var keys = kProj(x) var values = vProj(x) - keys = keys * keyMultiplier - queries = queries.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) keys = keys.reshaped(B, L, numKVHeads, -1).transposed(0, 2, 1, 3) values = values.reshaped(B, L, numKVHeads, -1).transposed(0, 2, 1, 3) if let cache { - queries = rope(queries, offset: cache.seqlenOffset) - keys = rope(keys, offset: cache.seqlenOffset) - (keys, values) = cache.update(keyStates: keys, valueStates: values, layerIdx: layerIdx) + queries = rope(queries, offset: cache.offset) + keys = rope(keys, offset: cache.offset) + (keys, values) = cache.update(keys: keys, values: values) } else { queries = rope(queries) keys = rope(keys) } - if var mask { - let kvSeqLen = keys.dim(-2) - if mask.ndim == 2 { - mask = mask[.newAxis, .newAxis, 0..., 0...] - } - - if kvSeqLen > L { - if mask.dim(-1) < kvSeqLen { - let numHeadsDim = mask.dim(1) > 1 ? mask.dim(1) : 1 - let padLength = kvSeqLen - mask.dim(-1) - let padShape = [B, numHeadsDim, L, padLength] - let padding = MLXArray.ones(padShape, dtype: mask.dtype) - mask = concatenated([padding, mask], axis: -1) - } - } - } - var output = MLXFast.scaledDotProductAttention( queries: queries, keys: keys, @@ -193,213 +329,59 @@ private class Attention: Module { } } -// MARK: - MLP +// MARK: - Mixer -private class MLP: Module, UnaryLayer { - @ModuleInfo(key: "gate_proj") var gateProj: Linear - @ModuleInfo(key: "up_proj") var upProj: Linear - @ModuleInfo(key: "down_proj") var downProj: Linear +private class Mixer: Module { + let numHeads: Int + let hiddenSize: Int + let ssmStateSize: Int + let convKernelSize: Int + let intermediateSize: Int + let useConvBias: Bool + let useBias: Bool + let layerNormEpsilon: Float + let groupsTimeStateSize: Int + let nGroups: Int + let headDim: Int + let chunkSize: Int + let timeStepLimit: (Float, Float) + let timeStepMin: Float + let timeStepMax: Float + let convDim: Int + let mambaRMSNorm: Bool + var norm: RMSNormGated? 
= nil + let ssmInMultiplier: Float + let conv1d: Conv1d - let gateMultiplier: Float - let downMultiplier: Float + @ModuleInfo(key: "in_proj") var inProj: Linear + @ParameterInfo(key: "dt_bias") var dtBias: MLXArray + @ParameterInfo(key: "A_log") var aLog: MLXArray + @ParameterInfo(key: "D") var d: MLXArray + @ModuleInfo(key: "out_proj") var outProj: Linear init(_ args: FalconH1Configuration) { - let hiddenSize = args.hiddenSize - let intermediateSize = args.intermediateSize ?? 4 * hiddenSize - - _gateProj.wrappedValue = Linear(hiddenSize, intermediateSize, bias: args.mlpBias) - _upProj.wrappedValue = Linear(hiddenSize, intermediateSize, bias: args.mlpBias) - _downProj.wrappedValue = Linear(intermediateSize, hiddenSize, bias: args.mlpBias) - - self.gateMultiplier = args.mlpMultipliers[0] - self.downMultiplier = args.mlpMultipliers[1] - } + self.numHeads = args.mambaNHeads + self.hiddenSize = args.hiddenSize + self.ssmStateSize = args.mambaDState + self.convKernelSize = args.mambaDConv + self.intermediateSize = args.mambaDSSM + self.useConvBias = args.mambaConvBias + self.useBias = args.mambaProjBias + self.layerNormEpsilon = args.rmsNormEps + self.groupsTimeStateSize = args.mambaNGroups * args.mambaDState + self.nGroups = args.mambaNGroups + self.headDim = args.mambaDHead + self.chunkSize = args.mambaChunkSize + self.timeStepLimit = (0.0, Float.infinity) + self.timeStepMin = 0.001 + self.timeStepMax = 0.1 - func callAsFunction(_ x: MLXArray) -> MLXArray { - let y = upProj(x) * silu(gateProj(x) * gateMultiplier) - return downProj(y) * downMultiplier - } -} - -// MARK: - DecoderLayer - -private class DecoderLayer: Module { - let mamba: Mixer - let channelsAttn: Int - let ssmOutMultiplier: Float - let attnOutMultiplier: Float - let attentionInMultiplier: Float - - @ModuleInfo(key: "feed_forward") var feedForward: MLP - @ModuleInfo(key: "self_attn") var attention: Attention - @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm - @ModuleInfo(key: "pre_ff_layernorm") var preFfLayerNorm: RMSNorm - - init(_ args: FalconH1Configuration, layerIdx: Int, mupVector: MLXArray) { - self.mamba = Mixer(args, layerIdx: layerIdx, mupVector: mupVector) - - let headDim = args.hiddenSize / args.numAttentionHeads - self.channelsAttn = args.numAttentionHeads * headDim + 2 * args.numKeyValueHeads * headDim - - self.attentionInMultiplier = args.attentionInMultiplier - self.ssmOutMultiplier = args.ssmOutMultiplier - self.attnOutMultiplier = args.attentionOutMultiplier - - _feedForward.wrappedValue = MLP(args) - _attention.wrappedValue = Attention(args, layerIdx: layerIdx) - _inputLayerNorm.wrappedValue = RMSNorm( - dimensions: args.hiddenSize, eps: args.rmsNormEps ?? 1e-5 - ) - _preFfLayerNorm.wrappedValue = RMSNorm( - dimensions: args.hiddenSize, eps: args.rmsNormEps ?? 1e-5 - ) - } - - func callAsFunction( - _ hiddenStates: MLXArray, - cache: Mamba2Cache?, - mask: MLXArray?, - mambaMask: MLXArray? 
= nil, - cachePosition: MLXArray - ) -> MLXArray { - var residual = hiddenStates - var hiddenStates = inputLayerNorm(hiddenStates) - - let mambaHiddenStates = - mamba( - hiddenStates, - cache: cache, - mask: mambaMask, - cachePosition: cachePosition - ) * ssmOutMultiplier - - let attentionHiddenStates = - attention( - hiddenStates * attentionInMultiplier, - mask: mask, - cache: cache - ) * attnOutMultiplier - - hiddenStates = residual + mambaHiddenStates + attentionHiddenStates - - residual = hiddenStates - hiddenStates = preFfLayerNorm(hiddenStates) - hiddenStates = feedForward(hiddenStates) - hiddenStates = residual + hiddenStates - - return hiddenStates - } -} - -private func applyMaskToPaddingStates(_ inputStates: MLXArray, _ attentionMask: MLXArray?) - -> MLXArray -{ - if let attentionMask { - return inputStates * expandedDimensions(attentionMask, axis: -1) - } - return inputStates -} - -private func padTensorBySize(_ tensor: MLXArray, _ padSize: Int) -> MLXArray { - if padSize > 0 { - var padShape = tensor.shape - padShape[1] = padSize - let padding = MLXArray.zeros(padShape, dtype: tensor.dtype) - return concatenated([tensor, padding], axis: 1) - } - return tensor -} - -private func reshapeIntoChunks(_ tensor: MLXArray, _ padSize: Int, _ chunkSize: Int) -> MLXArray { - var tensor = tensor - if padSize > 0 { - tensor = padTensorBySize(tensor, padSize) - } - - let tensorShape = tensor.shape[..<2] - let batchSize = tensorShape[0] - let seqLen = tensorShape[1] - let numChunks = seqLen / chunkSize - - var newShape = [batchSize, numChunks, chunkSize] - newShape.append(contentsOf: tensor.shape[2...]) - return tensor.reshaped(newShape) -} - -private func segmentSum(_ inputTensor: MLXArray) -> MLXArray { - let chunkSize = inputTensor.dim(-1) - var inputTensor = expandedDimensions(inputTensor, axis: -1) - inputTensor = broadcast( - inputTensor, to: inputTensor.shape.dropLast() + [chunkSize] - ) - - var mask = tri(chunkSize, k: -1, dtype: .bool) - inputTensor = MLX.where(mask, inputTensor, 0) - - let tensorSegsum = cumsum(inputTensor, axis: -2) - - mask = tri(chunkSize, k: 0, dtype: .bool) - return MLX.where(mask, tensorSegsum, -Float.infinity) -} - -// MARK: - Mixer - -private class Mixer: Module { - let numHeads: Int - let hiddenSize: Int - let ssmStateSize: Int - let convKernelSize: Int - let intermediateSize: Int - let layerIdx: Int - let useConvBias: Bool - let useBias: Bool - let layerNormEpsilon: Float - let groupsTimeStateSize: Int - let nGroups: Int - let headDim: Int - let chunkSize: Int - let timeStepLimit: (Float, Float) - let timeStepMin: Float - let timeStepMax: Float - let convDim: Int - let mambaRMSNorm: Bool - var norm: RMSNormGated? = nil - let ssmInMultiplier: Float - let conv1d: Conv1d - - let _mupVector: MLXArray - - @ModuleInfo(key: "in_proj") var inProj: Linear - @ParameterInfo(key: "dt_bias") var dtBias: MLXArray - @ParameterInfo(key: "A_log") var aLog: MLXArray - @ParameterInfo(key: "D") var d: MLXArray - @ModuleInfo(key: "out_proj") var outProj: Linear - - init(_ args: FalconH1Configuration, layerIdx: Int, mupVector: MLXArray) { - self.numHeads = args.mambaNHeads - self.hiddenSize = args.hiddenSize - self.ssmStateSize = args.mambaDState - self.convKernelSize = args.mambaDConv - self.intermediateSize = args.mambaDSSM ?? args.mambaExpand * args.hiddenSize - self.layerIdx = layerIdx - self.useConvBias = args.mambaConvBias - self.useBias = args.mambaProjBias - self.layerNormEpsilon = args.rmsNormEps ?? 
1e-5 - self.groupsTimeStateSize = args.mambaNGroups * args.mambaDState - self.nGroups = args.mambaNGroups - self.headDim = args.mambaDHead - self.chunkSize = args.mambaChunkSize - self.timeStepLimit = (0.0, Float.infinity) - self.timeStepMin = 0.001 - self.timeStepMax = 0.1 - - self.convDim = intermediateSize + 2 * nGroups * ssmStateSize + self.convDim = intermediateSize + 2 * nGroups * ssmStateSize self.conv1d = Conv1d( inputChannels: convDim, outputChannels: convDim, kernelSize: convKernelSize, - padding: convKernelSize - 1, groups: convDim, bias: useConvBias ) @@ -413,7 +395,7 @@ private class Mixer: Module { _dtBias.wrappedValue = MLXArray.ones([numHeads]) - let A = MLXArray(Array(1 ..< numHeads + 1)).asType(.float32) + let A = MLXArray(Array(1 ..< numHeads + 1)) _aLog.wrappedValue = log(A) @@ -427,7 +409,7 @@ private class Mixer: Module { ) } - _d.wrappedValue = MLXArray.ones([numHeads]) + 1.0 + _d.wrappedValue = MLXArray.ones([numHeads]) _outProj.wrappedValue = Linear( intermediateSize, @@ -436,240 +418,195 @@ private class Mixer: Module { ) self.ssmInMultiplier = args.ssmInMultiplier - self._mupVector = mupVector } - func callAsFunction( - _ inputStates: MLXArray, cache: Mamba2Cache? = nil, mask: MLXArray? = nil, - cachePosition: MLXArray? = nil - ) -> MLXArray { - let (batchSize, seqLen, _) = (inputStates.dim(0), inputStates.dim(1), inputStates.dim(2)) - let dtype = inputStates.dtype - - let mask: MLXArray? = mask?[..<1, .ellipsis] - - var inputStates = applyMaskToPaddingStates(inputStates, mask) - - inputStates = inputStates * ssmInMultiplier - var projectedStates = inProj(inputStates) - projectedStates = projectedStates * _mupVector - - let gate = projectedStates[.ellipsis, .. 0 - }() - - if usePrecomputedStates, let cache { - let convState = roll(cache.convStates[layerIdx]!, shift: -1, axis: -1) - convState[0..., 0..., -1] = hiddenStatesBC[0..., 0, 0...] - cache.convStates[layerIdx] = convState - - hiddenStatesBC = sum(convState * squeezed(conv1d.weight, axis: -1), axis: -1) - if useConvBias { - hiddenStatesBC = hiddenStatesBC + conv1d.bias! - } - hiddenStatesBC = silu(hiddenStatesBC) + private func _applyConv(_ convInput: MLXArray, cache: MambaCache?) -> MLXArray { + let convState: MLXArray + if cache == nil || cache?[0] == nil { + convState = MLXArray.zeros( + [convInput.dim(0), convKernelSize - 1, convDim], + dtype: convInput.dtype + ) } else { - if let cache { - let hiddenStatesBCTransposed = hiddenStatesBC.transposed(0, 2, 1) - let seqLenTransposed: Int = hiddenStatesBCTransposed.dim(-1) - let padSize = convKernelSize - seqLenTransposed - - let convStates: MLXArray = - if padSize > 0 { - padded( - hiddenStatesBCTransposed, - widths: [.init([0, 0]), .init([0, 0]), .init([padSize, 0])] - ) - } else { - hiddenStatesBCTransposed[0..., 0..., .. 
(MLXArray, MLXArray) { + let (batchSize, seqLen, _) = (hiddenStates.dim(0), hiddenStates.dim(1), hiddenStates.dim(2)) + + let hiddenStates = hiddenStates.reshaped(batchSize, seqLen, numHeads, headDim) + let B = B.reshaped(batchSize, seqLen, nGroups, ssmStateSize) + let C = C.reshaped(batchSize, seqLen, nGroups, ssmStateSize) + + let (y, newState) = ssmUpdate( + hiddenStates: hiddenStates, + ALog: aLog, + B: B, + C: C, + D: d, + dt: dt, + dtBias: dtBias, + state: state, + timeStepLimit: timeStepLimit, + mask: mask + ) - c = c.reshaped(batchSize, nGroups, -1) - c = expandedDimensions(c, axis: 2) - c = broadcast(c, to: [batchSize, nGroups, numHeads / nGroups, c.dim(-1)]) - c = c.reshaped(batchSize, -1, c.dim(-1)) + return (y.reshaped(batchSize, seqLen, intermediateSize), newState) + } - let ssmStates = cache.ssmStates[layerIdx]!.asType(c.dtype) + func callAsFunction( + _ inputStates: MLXArray, cache: MambaCache? = nil, mask: MLXArray? = nil + ) -> MLXArray { + let projectedStates = inProj(inputStates) - let ssmStatesReshaped = ssmStates.reshaped(batchSize * numHeads, headDim, ssmStateSize) - let cReshaped = c.reshaped(batchSize * numHeads, ssmStateSize, 1) + let splits = MLX.split( + projectedStates, + indices: [intermediateSize, intermediateSize + convDim], + axis: -1 + ) + let gate = splits[0] + var convInput = splits[1] + let dt = splits[2] - y = matmul(ssmStatesReshaped, cReshaped) - y = y.reshaped(batchSize, numHeads, headDim) + if let mask = mask { + convInput = which(mask[.ellipsis, .newAxis], convInput, 0) + } + let convOutput = _applyConv(convInput, cache: cache) + + let convSplits = MLX.split( + convOutput, + indices: [ + intermediateSize, + intermediateSize + nGroups * ssmStateSize + ], + axis: -1 + ) + let hiddenStatesSSM = convSplits[0] + let B = convSplits[1] + let C = convSplits[2] - var d = expandedDimensions(d, axis: -1) - d = broadcast(d, to: [d.dim(0), headDim]) - y = y + hiddenStates * d + var state = cache?[1] + var y: MLXArray + (y, state) = _ssm( + hiddenStates: hiddenStatesSSM, + B: B, + C: C, + dt: dt, + state: state, + mask: mask + ) + if let cache = cache { + cache[1] = state + } - y = y.reshaped(batchSize, -1) - y = expandedDimensions(y, axis: 1) + if let norm = norm { + y = norm(y, gate: gate) } else { - var dt = softplus(dt + dtBias) - dt = clip(dt, min: timeStepLimit.0, max: timeStepLimit.1) - - hiddenStates = hiddenStates.reshaped(batchSize, seqLen, -1, headDim).asType(.float32) - b = b.reshaped(batchSize, seqLen, -1, ssmStateSize).asType(.float32) - c = c.reshaped(batchSize, seqLen, -1, ssmStateSize).asType(.float32) - - b = repeated(b, count: numHeads / nGroups, axis: 2) - c = repeated(c, count: numHeads / nGroups, axis: 2) - - let padSize = (chunkSize - seqLen % chunkSize) % chunkSize - - let dResidual = expandedDimensions(d, axis: -1) * padTensorBySize(hiddenStates, padSize) - - hiddenStates = hiddenStates * expandedDimensions(dt, axis: -1) - a = a.asType(hiddenStates.dtype) * dt - - hiddenStates = reshapeIntoChunks(hiddenStates, padSize, chunkSize) - a = reshapeIntoChunks(a, padSize, chunkSize) - b = reshapeIntoChunks(b, padSize, chunkSize) - c = reshapeIntoChunks(c, padSize, chunkSize) - - a = a.transposed(0, 3, 1, 2) - let aCumsum = cumsum(a, axis: -1) + y = y * silu(gate) + } - let L = exp(segmentSum(a)) + return outProj(y) + } +} - var cExpanded = expandedDimensions(c, axis: 3) - let bExpanded = expandedDimensions(b, axis: 2) - let gIntermediate = cExpanded * bExpanded - let g = sum(gIntermediate, axis: -1) +// MARK: - MLP - let LPermuted = 
L.transposed(0, 2, 3, 4, 1) - let MIntermediate = - expandedDimensions(g, axis: -1) * expandedDimensions(LPermuted, axis: -1) - let m = sum(MIntermediate, axis: -1) +private class MLP: Module, UnaryLayer { + @ModuleInfo(key: "gate_proj") var gateProj: Linear + @ModuleInfo(key: "up_proj") var upProj: Linear + @ModuleInfo(key: "down_proj") var downProj: Linear - var hiddenStatesExpanded = expandedDimensions(hiddenStates, axis: 2) - let mExpanded = expandedDimensions(m, axis: -1) - let yDiag = sum(mExpanded * hiddenStatesExpanded, axis: 3) + let gateMultiplier: Float + let downMultiplier: Float - let decayStates = exp(aCumsum[0..., 0..., 0..., (-1)...] - aCumsum) - let decayStatesPermuted = decayStates.transposed(0, 2, 3, 1) - let bDecay = b * expandedDimensions(decayStatesPermuted, axis: -1) + init(_ args: FalconH1Configuration) { + let hiddenSize = args.hiddenSize + let intermediateSize = args.intermediateSize ?? 4 * hiddenSize - let bDecayExpanded = expandedDimensions(bDecay, axis: -2) - hiddenStatesExpanded = expandedDimensions(hiddenStates, axis: -1) - var states = sum(bDecayExpanded * hiddenStatesExpanded, axis: 2) + _gateProj.wrappedValue = Linear(hiddenSize, intermediateSize, bias: args.mlpBias) + _upProj.wrappedValue = Linear(hiddenSize, intermediateSize, bias: args.mlpBias) + _downProj.wrappedValue = Linear(intermediateSize, hiddenSize, bias: args.mlpBias) - let previousStates: MLXArray = - if usePrecomputedStates, let cache { - expandedDimensions(cache.ssmStates[layerIdx]!, axis: 1) - } else { - MLXArray.zeros(like: states[0..., ..<1]) - } + self.gateMultiplier = args.mlpMultipliers[0] + self.downMultiplier = args.mlpMultipliers[1] + } - states = concatenated([previousStates, states], axis: 1) + func callAsFunction(_ x: MLXArray) -> MLXArray { + let y = upProj(x) * silu(gateProj(x)) + return downProj(y) + } +} - let ACumsumLast = aCumsum[0..., 0..., 0..., -1] - let padded = padded(ACumsumLast, widths: [.init((0, 0)), .init((0, 0)), .init((1, 0))]) - var decayChunk = exp(segmentSum(padded)) - decayChunk = decayChunk.transposed(0, 3, 2, 1) +// MARK: - DecoderLayer - let decayExpanded = expandedDimensions( - expandedDimensions(decayChunk, axis: -1), axis: -1 - ) - var statesExpanded = expandedDimensions(states, axis: 2) - let newStates = sum(decayExpanded * statesExpanded, axis: 1) +private class DecoderLayer: Module { + @ModuleInfo(key: "feed_forward") var feedForward: MLP + @ModuleInfo(key: "mamba") var mamba: Mixer + @ModuleInfo(key: "self_attn") var attention: Attention + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm + @ModuleInfo(key: "pre_ff_layernorm") var preFfLayerNorm: RMSNorm - states = newStates[0..., ..<(-1)] - let ssmState = newStates[0..., -1] + let channelsAttn: Int - let stateDecayOut = exp(aCumsum) - cExpanded = expandedDimensions(c, axis: -2) - statesExpanded = expandedDimensions(states, axis: 2) - let cTimesStates = cExpanded * statesExpanded + init(_ args: FalconH1Configuration) { + let headDim = args.headDim + self.channelsAttn = args.numAttentionHeads * headDim + 2 * args.numKeyValueHeads * headDim - let stateDecayOutPermuted = stateDecayOut.transposed(0, 2, 3, 1) - let cTimesStatesSum = sum(cTimesStates, axis: -1) - let yOff = cTimesStatesSum * expandedDimensions(stateDecayOutPermuted, axis: -1) + _feedForward.wrappedValue = MLP(args) + _mamba.wrappedValue = Mixer(args) + _attention.wrappedValue = Attention(args) + _inputLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps + ) + _preFfLayerNorm.wrappedValue = 
RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps + ) + } - y = yDiag + yOff - y = y.reshaped(batchSize, -1, numHeads, headDim) - y = y + dResidual + func callAsFunction( + _ h: MLXArray, + cache: CacheList?, + attnMask: MLXArray?, + mambaMask: MLXArray? + ) -> MLXArray { + var residual = h + var h = inputLayerNorm(h) - if padSize > 0 { - y = y[0..., .. MLXArray { - var h = embedTokens(inputs) * args.embeddingMultiplier - let mask = mask ?? createAttentionMask(h: h, cache: cache) - let cache: [Mamba2Cache?] = cache ?? Array(repeating: nil, count: layers.count) + var h = embedTokens(inputs) - var cachePosition = MLXArray(0 ..< h.dim(1)).asType(.int32) + let cache: [CacheList?] = cache ?? Array(repeating: nil, count: layers.count) - if h.dim(1) == 1, let c = cache[0] { - let prevSeqlen = c.keyCache[0].dim(-2) - cachePosition = cachePosition + prevSeqlen - } + let mambaMask = createSSMMask(h: h, cache: cache[0]?[0] as? MambaCache) + let attnMask: MLXArray? = createAttentionMask(h: h, cache: cache[0]?[1] != nil ? [cache[0]![1]] : nil) for (layer, c) in zip(layers, cache) { h = layer( h, cache: c, - mask: mask, - cachePosition: cachePosition + attnMask: attnMask, + mambaMask: mambaMask ) } @@ -747,9 +679,9 @@ public class FalconH1Model: Module, LLMModel, KVCacheDimensionProvider { } public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray { - var out = model(inputs, cache: cache as? [Mamba2Cache]) + var out = model(inputs, cache: cache as? [CacheList]) if let lmHead { - out = lmHead(out) * configuration.lmHeadMultiplier + out = lmHead(out) } else { out = model.embedTokens.asLinear(out) } @@ -758,18 +690,46 @@ public class FalconH1Model: Module, LLMModel, KVCacheDimensionProvider { } public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { - var weights = weights - for (name, param) in weights { - if name.contains("conv1d.weight"), param.dim(-1) > param.dim(1) { - weights[name] = param.transposed(0, 2, 1) + let c1d = weights["model.layers.0.mamba.conv1d.weight"]! + if c1d.dim(-1) <= c1d.dim(1) { + return weights + } + + var sanitizedWeights = [String: MLXArray]() + let args = configuration + + for (name, var param) in weights { + if name.hasSuffix("embed_tokens.weight") { + param = param * args.embeddingMultiplier + } else if name.hasSuffix("lm_head.weight") { + param = param * args.lmHeadMultiplier + } else if name.hasSuffix("q_proj.weight") || name.hasSuffix("k_proj.weight") { + param = param * args.attentionInMultiplier + } else if name.hasSuffix("key_proj.weight") { + param = param * args.attentionInMultiplier * args.keyMultiplier + } else if name.hasSuffix("o_proj.weight") { + param = param * args.attentionOutMultiplier + } else if name.hasSuffix("out_proj.weight") { + param = param * args.ssmOutMultiplier + } else if name.hasSuffix("gate_proj.weight") { + param = param * args.mlpMultipliers[0] + } else if name.hasSuffix("down_proj.weight") { + param = param * args.mlpMultipliers[1] + } else if name.hasSuffix("in_proj.weight") { + let mupVector = computeMupVector(args) + param = param * (args.ssmInMultiplier * mupVector.asType(param.dtype)[0..., .newAxis]) + } else if name.contains("conv1d.weight") { + param = param.transposed(0, 2, 1) } + + sanitizedWeights[name] = param } - return weights + return sanitizedWeights } public func newCache(parameters: GenerateParameters?) 
-> [any KVCache] { - model.layers.map { _ in Mamba2Cache(configuration) } + model.layers.map { _ in CacheList(MambaCache(), KVCacheSimple()) } } } @@ -781,300 +741,3 @@ extension FalconH1Model: LoRAModel { } } -// MARK: - Configuration - -public struct FalconH1Configuration: Codable, Sendable { - var attentionBias: Bool - var attentionDropout: Float - var attentionInMultiplier: Float - var attentionOutMultiplier: Float - var bosTokenId: Int - var embeddingMultiplier: Float - var eosTokenId: Int - var headDim: Int? - var hiddenAct: String - var hiddenSize: Int - var initializerRange: Float - var intermediateSize: Int? - var keyMultiplier: Float - var lmHeadMultiplier: Float - var mambaChunkSize: Int - var mambaConvBias: Bool - var mambaDConv: Int - var mambaDHead: Int - var mambaDSSM: Int? - var mambaDState: Int - var mambaExpand: Int - var mambaNGroups: Int - var mambaNHeads: Int - var mambaNormBeforeGate: Bool - var mambaProjBias: Bool - var mambaRMSNorm: Bool - var mambaUseMLP: Bool - var maxPositionEmbeddings: Int - var mlpBias: Bool - var mlpExpansionFactor: Int - var mlpMultipliers: [Float] - var modelType: String - var numAttentionHeads: Int - var numHiddenLayers: Int - var numKeyValueHeads: Int - var numLogitsToKeep: Int - var padTokenId: Int - var projectorsBias: Bool - var rmsNormEps: Float? - var ropeTraditional: Bool - var ropeScaling: Float? - var ropeTheta: Float - var ssmInMultiplier: Float - var ssmMultipliers: [Float] - var ssmOutMultiplier: Float - var tieWordEmbeddings: Bool - var torchDtype: String - var vocabSize: Int - - enum CodingKeys: String, CodingKey { - case attentionBias = "attention_bias" - case attentionDropout = "attention_dropout" - case attentionInMultiplier = "attention_in_multiplier" - case attentionOutMultiplier = "attention_out_multiplier" - case bosTokenId = "bos_token_id" - case embeddingMultiplier = "embedding_multiplier" - case eosTokenId = "eos_token_id" - case headDim = "head_dim" - case hiddenAct = "hidden_act" - case hiddenSize = "hidden_size" - case initializerRange = "initializer_range" - case intermediateSize = "intermediate_size" - case keyMultiplier = "key_multiplier" - case lmHeadMultiplier = "lm_head_multiplier" - case mambaChunkSize = "mamba_chunk_size" - case mambaConvBias = "mamba_conv_bias" - case mambaDConv = "mamba_d_conv" - case mambaDHead = "mamba_d_head" - case mambaDSSM = "mamba_d_ssm" - case mambaDState = "mamba_d_state" - case mambaExpand = "mamba_expand" - case mambaNGroups = "mamba_n_groups" - case mambaNHeads = "mamba_n_heads" - case mambaNormBeforeGate = "mamba_norm_before_gate" - case mambaProjBias = "mamba_proj_bias" - case mambaRMSNorm = "mamba_rms_norm" - case mambaUseMLP = "mamba_use_mlp" - case maxPositionEmbeddings = "max_position_embeddings" - case mlpBias = "mlp_bias" - case mlpExpansionFactor = "mlp_expansion_factor" - case mlpMultipliers = "mlp_multipliers" - case modelType = "model_type" - case numAttentionHeads = "num_attention_heads" - case numHiddenLayers = "num_hidden_layers" - case numKeyValueHeads = "num_key_value_heads" - case numLogitsToKeep = "num_logits_to_keep" - case padTokenId = "pad_token_id" - case projectorsBias = "projectors_bias" - case rmsNormEps = "rms_norm_eps" - case ropeTraditional = "rope_traditional" - case ropeScaling = "rope_scaling" - case ropeTheta = "rope_theta" - case ssmInMultiplier = "ssm_in_multiplier" - case ssmMultipliers = "ssm_multipliers" - case ssmOutMultiplier = "ssm_out_multiplier" - case tieWordEmbeddings = "tie_word_embeddings" - case torchDtype = "torch_dtype" - case 
vocabSize = "vocab_size" - } - - public init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - self.attentionBias = - try container.decodeIfPresent(Bool.self, forKey: .attentionBias) ?? false - self.attentionDropout = - try container.decodeIfPresent(Float.self, forKey: .attentionDropout) ?? 0.0 - self.attentionInMultiplier = - try container.decodeIfPresent(Float.self, forKey: .attentionInMultiplier) ?? 1.0 - self.attentionOutMultiplier = - try container.decodeIfPresent(Float.self, forKey: .attentionOutMultiplier) ?? 1.0 - self.bosTokenId = try container.decodeIfPresent(Int.self, forKey: .bosTokenId) ?? 1 - self.embeddingMultiplier = - try container.decodeIfPresent(Float.self, forKey: .embeddingMultiplier) ?? 1.0 - self.eosTokenId = try container.decodeIfPresent(Int.self, forKey: .eosTokenId) ?? 2 - self.headDim = try container.decodeIfPresent(Int.self, forKey: .headDim) ?? nil - self.hiddenAct = try container.decodeIfPresent(String.self, forKey: .hiddenAct) ?? "silu" - self.hiddenSize = try container.decodeIfPresent(Int.self, forKey: .hiddenSize) ?? 4096 - self.initializerRange = - try container.decodeIfPresent(Float.self, forKey: .initializerRange) ?? 0.02 - self.intermediateSize = - try container.decodeIfPresent(Int.self, forKey: .intermediateSize) ?? nil - self.keyMultiplier = - try container.decodeIfPresent(Float.self, forKey: .keyMultiplier) ?? 1.0 - self.lmHeadMultiplier = - try container.decodeIfPresent(Float.self, forKey: .lmHeadMultiplier) ?? 1.0 - self.mambaChunkSize = - try container.decodeIfPresent(Int.self, forKey: .mambaChunkSize) ?? 256 - self.mambaConvBias = - try container.decodeIfPresent(Bool.self, forKey: .mambaConvBias) ?? true - self.mambaDConv = try container.decodeIfPresent(Int.self, forKey: .mambaDConv) ?? 4 - self.mambaDHead = try container.decodeIfPresent(Int.self, forKey: .mambaDHead) ?? 64 - self.mambaDSSM = try container.decodeIfPresent(Int.self, forKey: .mambaDSSM) ?? nil - self.mambaDState = try container.decodeIfPresent(Int.self, forKey: .mambaDState) ?? 256 - self.mambaExpand = try container.decodeIfPresent(Int.self, forKey: .mambaExpand) ?? 2 - self.mambaNGroups = try container.decodeIfPresent(Int.self, forKey: .mambaNGroups) ?? 1 - self.mambaNHeads = try container.decodeIfPresent(Int.self, forKey: .mambaNHeads) ?? 128 - self.mambaNormBeforeGate = - try container.decodeIfPresent(Bool.self, forKey: .mambaNormBeforeGate) ?? true - self.mambaProjBias = - try container.decodeIfPresent(Bool.self, forKey: .mambaProjBias) ?? false - self.mambaRMSNorm = try container.decodeIfPresent(Bool.self, forKey: .mambaRMSNorm) ?? false - self.mambaUseMLP = try container.decodeIfPresent(Bool.self, forKey: .mambaUseMLP) ?? true - self.maxPositionEmbeddings = - try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings) ?? 8192 - self.mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) ?? false - self.mlpExpansionFactor = - try container.decodeIfPresent(Int.self, forKey: .mlpExpansionFactor) ?? 8 - self.mlpMultipliers = - try container.decodeIfPresent([Float].self, forKey: .mlpMultipliers) ?? [1.0, 1.0] - self.modelType = - try container.decodeIfPresent(String.self, forKey: .modelType) ?? "falcon_h1" - self.numAttentionHeads = - try container.decodeIfPresent(Int.self, forKey: .numAttentionHeads) ?? 32 - self.numHiddenLayers = - try container.decodeIfPresent(Int.self, forKey: .numHiddenLayers) ?? 32 - self.numKeyValueHeads = - try container.decodeIfPresent(Int.self, forKey: .numKeyValueHeads) ?? 
8 - self.numLogitsToKeep = - try container.decodeIfPresent(Int.self, forKey: .numLogitsToKeep) ?? 1 - self.padTokenId = try container.decodeIfPresent(Int.self, forKey: .padTokenId) ?? 0 - self.projectorsBias = - try container.decodeIfPresent(Bool.self, forKey: .projectorsBias) ?? false - self.rmsNormEps = try container.decodeIfPresent(Float.self, forKey: .rmsNormEps) ?? nil - self.ropeTraditional = - try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) ?? false - self.ropeScaling = try container.decodeIfPresent(Float?.self, forKey: .ropeScaling) ?? nil - self.ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 100000.0 - self.ssmInMultiplier = - try container.decodeIfPresent(Float.self, forKey: .ssmInMultiplier) ?? 1.0 - self.ssmMultipliers = - try container.decodeIfPresent([Float].self, forKey: .ssmMultipliers) ?? [ - 1.0, 1.0, 1.0, 1.0, 1.0, - ] - self.ssmOutMultiplier = - try container.decodeIfPresent(Float.self, forKey: .ssmOutMultiplier) ?? 1.0 - self.tieWordEmbeddings = - try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false - self.torchDtype = - try container.decodeIfPresent(String.self, forKey: .torchDtype) ?? "bfloat16" - self.vocabSize = try container.decodeIfPresent(Int.self, forKey: .vocabSize) ?? 128000 - } -} - -// MARK: - Mamba2Cache KVCache - -private class Mamba2Cache: KVCache { - var offset: Int - - var maxSize: Int? - - func innerState() -> [MLXArray] { - [] - } - - var seqlenOffset: Int = 0 - var hasPreviousState: Bool = false - let convKernelSize: Int - - private var _seenTokens: Int = 0 - - let intermediateSize: Int - - var convStates: [Int: MLXArray] - var ssmStates: [Int: MLXArray] - - var transformerLayers: [Int] - var keyCache: [MLXArray] - var valueCache: [MLXArray] - - init(_ args: FalconH1Configuration, batchSize: Int = 1) { - self.convKernelSize = args.mambaDConv - - self.intermediateSize = - args.mambaDSSM ?? 
args.mambaExpand * args.hiddenSize - - self.convStates = [:] - self.ssmStates = [:] - - let convStateShape = [ - batchSize, - intermediateSize + 2 * args.mambaNGroups * args.mambaDState, - convKernelSize, - ] - let ssmStateShape = [ - batchSize, - args.mambaNHeads, - args.mambaDHead, - args.mambaDState, - ] - - for i in 0 ..< args.numHiddenLayers { - convStates[i] = MLXArray.zeros(convStateShape) - ssmStates[i] = MLXArray.zeros(ssmStateShape) - } - - self.seqlenOffset = 0 - self.hasPreviousState = false - self.transformerLayers = Array(0 ..< args.numHiddenLayers) - self.keyCache = [] - self.valueCache = [] - self.offset = 0 - } - - func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) { - update(keyStates: keys, valueStates: values, layerIdx: 0) - } - - func update(keyStates: MLXArray, valueStates: MLXArray, layerIdx: Int) -> (MLXArray, MLXArray) { - if layerIdx == 0 { - _seenTokens += keyStates.dim(-2) - } - - if keyCache.count <= layerIdx { - for _ in keyCache.count ..< layerIdx { - keyCache.append([]) - valueCache.append([]) - } - keyCache.append(keyStates) - valueCache.append(valueStates) - } else if keyCache[layerIdx].size == 0 { - keyCache[layerIdx] = keyStates - valueCache[layerIdx] = valueStates - } else { - keyCache[layerIdx] = concatenated([keyCache[layerIdx], keyStates], axis: -2) - valueCache[layerIdx] = concatenated([valueCache[layerIdx], valueStates], axis: -2) - } - - return (keyCache[layerIdx], valueCache[layerIdx]) - } - - func updateConvState(layerIdx: Int, newConvState: MLXArray, cachePosition: MLXArray) -> MLXArray - { - var convState = convStates[layerIdx]! - let cachePosition = clip(cachePosition, min: 0, max: convKernelSize - 1) - - convState = roll(convState, shift: -1, axis: -1) - - if cachePosition.count > 1 { - convState[0..., 0..., 0...] = newConvState.transposed(0, 2, 1) - } else { - convState[0..., 0..., -1] = newConvState[0..., 0..., -1] - } - - convStates[layerIdx] = convState - return convStates[layerIdx]! - } - - func reset() { - for i in 0 ..< convStates.count { - convStates[i] = MLXArray.zeros(like: convStates[i]!) - ssmStates[i] = MLXArray.zeros(like: ssmStates[i]!) - } - } -} From 7958e18c118b47baec5c15f1b3a1b70eeeabf5af Mon Sep 17 00:00:00 2001 From: John Mai Date: Wed, 1 Oct 2025 12:31:48 +0800 Subject: [PATCH 5/7] Add ArraysCache --- Libraries/MLXLMCommon/KVCache.swift | 81 ++++++++++++++++++----------- 1 file changed, 50 insertions(+), 31 deletions(-) diff --git a/Libraries/MLXLMCommon/KVCache.swift b/Libraries/MLXLMCommon/KVCache.swift index 75559c8f..d3fe3eb8 100644 --- a/Libraries/MLXLMCommon/KVCache.swift +++ b/Libraries/MLXLMCommon/KVCache.swift @@ -206,6 +206,13 @@ public func createAttentionMask(h: MLXArray, cache: [KVCache]?, returnArray: Boo return .none } +public func createSSMMask(h: MLXArray, cache: MambaCache?) -> MLXArray? { + if let cache { + return cache.makeMask(N: h.dim(1)) + } + return nil +} + /// Standard KV cache implementation based on Python's KVCache /// See https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/base.py#L11 public class KVCacheSimple: BaseKVCache, CustomDebugStringConvertible { @@ -887,11 +894,14 @@ public class ChunkedKVCache: KVCacheSimple { } } -/// Simple cache for Mamba-style state space models -public class MambaCache: BaseKVCache { - private var cache: [MLXArray?] = [nil, nil] +/// Base cache for array-based state storage +public class ArraysCache: BaseKVCache { + private var cache: [MLXArray?] + private var leftPadding: MLXArray? 
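+    // `leftPadding` optionally stores the number of left-padding tokens per batch
+    // entry; `makeMask(N:)` below turns it into a boolean mask so SSM layers can
+    // ignore padded positions on the prefill call, before any state is cached.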
- public override init() { + public init(size: Int, leftPadding: [Int]? = nil) { + self.cache = Array(repeating: nil, count: size) + self.leftPadding = leftPadding.map { MLXArray($0) } super.init() } @@ -904,39 +914,48 @@ public class MambaCache: BaseKVCache { set { cache[index] = newValue } } - public override func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) { - // Mamba doesn't use traditional KV cache update pattern - fatalError("MambaCache should not use update(keys:values:) - use subscript access instead") - } - public override var state: [MLXArray] { get { - // Need to preserve the structure including nils, similar to Python version - // Use empty arrays as placeholders for nil values - var result: [MLXArray] = [] - for item in cache { - if let array = item { - result.append(array) - } else { - // Use an empty array as placeholder for nil (this shape should never occur naturally) - result.append(MLXArray.zeros([0], dtype: .float32)) - } - } - return result + return cache.compactMap { $0 } } set { - guard newValue.count == cache.count else { - fatalError("MambaCache state must have exactly \(cache.count) elements") - } - for (i, array) in newValue.enumerated() { - // Check if this is our nil placeholder (empty array with size 0) - if array.size == 0 { - cache[i] = nil - } else { - cache[i] = array - } + cache = newValue.map { $0 as MLXArray? } + } + } + + /// In-place filter to keep just the given indices in the cache + public func filter(batchIndices: MLXArray) { + cache = cache.map { c in + c?[batchIndices] + } + leftPadding = nil + } + + /// In-place extend this cache with the other cache + public func extend(other: ArraysCache) { + cache = zip(cache, other.cache).map { (c, o) in + if let c = c, let o = o { + return MLX.concatenated([c, o]) } + return c ?? o } + leftPadding = nil + } + + /// Create attention mask based on left padding + public func makeMask(N: Int) -> MLXArray? { + if cache[0] == nil, let leftPadding = leftPadding { + return MLXArray(0 ..< N) .>= leftPadding[0..., .newAxis] + } else { + return nil + } + } +} + +/// Simple cache for Mamba-style state space models +public class MambaCache: ArraysCache { + public init(leftPadding: [Int]? = nil) { + super.init(size: 2, leftPadding: leftPadding) } } From 74d3e225281255d6c632d2bb2d1c1fc001053387 Mon Sep 17 00:00:00 2001 From: John Mai Date: Wed, 1 Oct 2025 12:56:59 +0800 Subject: [PATCH 6/7] Add SSM --- Libraries/MLXLLM/Models/SSM.swift | 242 ++++++++++++++++++++++++++++++ 1 file changed, 242 insertions(+) create mode 100644 Libraries/MLXLLM/Models/SSM.swift diff --git a/Libraries/MLXLLM/Models/SSM.swift b/Libraries/MLXLLM/Models/SSM.swift new file mode 100644 index 00000000..0111a64b --- /dev/null +++ b/Libraries/MLXLLM/Models/SSM.swift @@ -0,0 +1,242 @@ +// +// SSM.swift +// mlx-swift-examples +// +// Created by John Mai on 2025/10/01. +// + +import Foundation +import MLX +import MLXFast +import MLXNN + +public func computeDt(_ dt: MLXArray, _ dtBias: MLXArray, _ timeStepLimit: (Float, Float)) + -> MLXArray +{ + let dt = softplus(dt + dtBias) + return MLX.clip(dt, min: timeStepLimit.0, max: timeStepLimit.1) +} + +private func makeSSMKernel() -> MLXFast.MLXFastKernel? 
{ + let source = """ + auto n = thread_position_in_grid.z; + auto h_idx = n % H; + auto g_idx = n / G; + constexpr int n_per_t = Ds / 32; + + auto x = X + n * Dh; + out += n * Dh; + auto i_state = state_in + n * Dh * Ds; + auto o_state = state_out + n * Dh * Ds; + + // C and B have shape [batch, group, state_dim] + // C and B need to be offset by group size + auto C_ = C + g_idx * Ds; + auto B_ = B + g_idx * Ds; + + auto ds_idx = thread_position_in_threadgroup.x; + auto d_idx = thread_position_in_grid.y; + + auto dt_ = static_cast(dt[n]); + auto A = -fast::exp(static_cast(A_log[h_idx])); + auto dA = fast::exp(A * dt_); + + float acc = 0.0; + auto x_ = static_cast(x[d_idx]); + + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * ds_idx + i; + auto idx = d_idx * Ds + s_idx; + auto dB_by_x = x_ * dt_ * static_cast(B_[s_idx]); + auto state = dA * i_state[idx] + dB_by_x; + o_state[idx] = static_cast(state); + acc += state * C_[s_idx]; + } + acc = simd_sum(acc); + if (thread_index_in_simdgroup == 0) { + out[d_idx] = static_cast(acc + x_ * D[h_idx]); + } + """ + + return MLXFast.metalKernel( + name: "ssm_kernel", + inputNames: ["X", "A_log", "B", "C", "D", "dt", "state_in"], + outputNames: ["out", "state_out"], + source: source + ) +} + +private final class SSMKernelManager: @unchecked Sendable { + static let shared = SSMKernelManager() + + let ssmKernel: MLXFast.MLXFastKernel? + + private init() { + ssmKernel = makeSSMKernel() + } +} + + +func ssmUpdateKernel( + hiddenStates: MLXArray, + ALog: MLXArray, + B: MLXArray, + C: MLXArray, + D: MLXArray, + dt: MLXArray, + dtBias: MLXArray, + state: MLXArray, + timeStepLimit: (Float, Float) +) -> (MLXArray, MLXArray) { + let (n, _, h, d) = hiddenStates.shape4 + let inputType = hiddenStates.dtype + let (hb, ds) = (B.dim(-2), B.dim(-1)) + + let dt = computeDt(dt, dtBias, timeStepLimit) + + guard let kernel = SSMKernelManager.shared.ssmKernel else { + fatalError("SSM kernel not available") + } + + let outputs = kernel( + [hiddenStates, ALog, B, C, D, dt, state], + template: [ + ("T", inputType), + ("Dh", d), + ("Ds", ds), + ("H", h), + ("G", h / hb), + ], + grid: (32, d, h * n), + threadGroup: (32, 8, 1), + outputShapes: [[n, 1, h, d], state.shape], + outputDTypes: [inputType, inputType] + ) + + return (outputs[0], outputs[1]) +} + + +public func segsum(_ x: MLXArray, mask: MLXArray? = nil) -> MLXArray { + let l = x.dim(-1) + var x = x + + if let mask = mask { + let mask = MLX.expandedDimensions(mask, axis: 1) + x = x * mask + } + + x = MLX.repeated(x[.ellipsis, .newAxis], count: l, axis: -1) + x = MLX.tril(x, k: -1) + var xSegsum = MLX.cumsum(x, axis: -2) + + if let mask = mask { + xSegsum = which( + mask[.ellipsis, .newAxis, 0...] * mask[.ellipsis, .newAxis], + xSegsum, + MLXArray(-Float.infinity) + ) + } + + return xSegsum +} + + +public func ssmAttn( + x: MLXArray, + ALog: MLXArray, + B: MLXArray, + C: MLXArray, + D: MLXArray, + dt: MLXArray, + dtBias: MLXArray, + state: MLXArray? = nil, + timeStepLimit: (Float, Float) = (0.001, 100.0), + mask: MLXArray? 
= nil +) -> (MLXArray, MLXArray) { + let (b, l, h, dh) = x.shape4 + let (_, _, g, d) = B.shape4 + + let dt = computeDt(dt, dtBias, timeStepLimit) + let repeats = h / g + let A = -MLX.exp(ALog) + var B = MLX.transposed(B, axes: [0, 2, 3, 1]) + + // A * s + B * C + var CB = MLX.swappedAxes(C, 1, 2).matmul(B) + CB = MLX.repeated(CB, count: repeats, axis: 1) + + let dtA = dt * A.reshaped(1, 1, -1) + var decay = MLX.exp(segsum(dtA.swappedAxes(1, 2), mask: mask)) + + let surrogateAttentionMatrix = MLX.tril(CB * decay, k: 0) + + let dtx = dt.reshaped(b, l, h, 1) * x + var y = surrogateAttentionMatrix.matmul(dtx.swappedAxes(1, 2)) + y = MLX.swappedAxes(y, 1, 2) + + decay = decay[0..., 0..., (-1)..., 0...].transposed(0, 3, 1, 2) + B = MLX.repeated(B, count: h / g, axis: 1).swappedAxes(2, 3) + var dtxdecay = dtx * decay + dtxdecay = dtxdecay.swappedAxes(1, 2).swappedAxes(2, 3) + + var nextState = dtxdecay.matmul(B) + + if var state = state { + let expDtACumsum = MLX.exp(MLX.cumsum(dtA, axis: -2)) + nextState = nextState + expDtACumsum[0..., -1, 0..., .newAxis, .newAxis] * state + state = state.reshaped(b, 1, g, repeats, dh, d) + let C = C.reshaped(b, l, g, 1, d, 1) + let yPrev = (state.matmul(C)).squeezed(axis: -1).flattened(start: 2, end: 3) + y = y + expDtACumsum[.ellipsis, .newAxis] * yPrev + } + + y = y + x * D.reshaped(1, 1, h, 1) + return (y, nextState) +} + +public func ssmUpdate( + hiddenStates: MLXArray, + ALog: MLXArray, + B: MLXArray, + C: MLXArray, + D: MLXArray, + dt: MLXArray, + dtBias: MLXArray, + state: MLXArray? = nil, + timeStepLimit: (Float, Float) = (0.001, 100.0), + mask: MLXArray? = nil +) -> (MLXArray, MLXArray) { + let seqLen = hiddenStates.dim(1) + + + if seqLen == 1, + let state = state, + SSMKernelManager.shared.ssmKernel != nil + { + return ssmUpdateKernel( + hiddenStates: hiddenStates, + ALog: ALog, + B: B, + C: C, + D: D, + dt: dt, + dtBias: dtBias, + state: state, + timeStepLimit: timeStepLimit + ) + } else { + return ssmAttn( + x: hiddenStates, + ALog: ALog, + B: B, + C: C, + D: D, + dt: dt, + dtBias: dtBias, + state: state, + timeStepLimit: timeStepLimit, + mask: mask + ) + } +} From 47e7a20a7a2ce9ce2f7561d9c68f04a3f8b68ab0 Mon Sep 17 00:00:00 2001 From: John Mai Date: Wed, 1 Oct 2025 13:22:20 +0800 Subject: [PATCH 7/7] update --- Libraries/MLXLLM/Models/FalconH1.swift | 63 ++++++++++++++++---------- Libraries/MLXLLM/Models/SSM.swift | 4 -- 2 files changed, 39 insertions(+), 28 deletions(-) diff --git a/Libraries/MLXLLM/Models/FalconH1.swift b/Libraries/MLXLLM/Models/FalconH1.swift index a79446ab..3040960e 100644 --- a/Libraries/MLXLLM/Models/FalconH1.swift +++ b/Libraries/MLXLLM/Models/FalconH1.swift @@ -12,7 +12,6 @@ import MLXNN // port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/falcon_h1.py - // MARK: - Configuration public struct FalconH1Configuration: Codable, Sendable { @@ -199,7 +198,6 @@ public struct FalconH1Configuration: Codable, Sendable { } } - // MARK: - RMSNormGated private class RMSNormGated: Module { @@ -242,7 +240,7 @@ private func computeMupVector(_ args: FalconH1Configuration) -> MLXArray { intermediateSize, groupsTimeStateSize, groupsTimeStateSize, - numHeads + numHeads, ] let segments = zip(sizes, args.ssmMultipliers).map { size, multiplier in @@ -295,8 +293,7 @@ private class Attention: Module { ) } - func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? = nil) -> MLXArray - { + func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? 
= nil) -> MLXArray { let (B, L, _) = (x.dim(0), x.dim(1), x.dim(2)) var queries = qProj(x) @@ -494,7 +491,7 @@ private class Mixer: Module { convOutput, indices: [ intermediateSize, - intermediateSize + nGroups * ssmStateSize + intermediateSize + nGroups * ssmStateSize, ], axis: -1 ) @@ -606,6 +603,25 @@ private class DecoderLayer: Module { } } +// MARK: - Helper Functions + +private func createSSMMask(h: MLXArray, cache: ArraysCache?) -> MLXArray? { + if let cache = cache { + return cache.makeMask(N: h.dim(1)) + } + return nil +} + +private func createAttentionMask(h: MLXArray, cache: [KVCache]?) -> MLXArray? { + let N = h.dim(1) + // If cache exists and can make masks, use it + // Otherwise for single token, no mask needed + // For multi-token, SDPA will handle causal mask internally when nil + if N == 1 { + return nil + } + return nil // Will be handled by SDPA internally when nil +} // MARK: - Model @@ -614,7 +630,8 @@ private class ModelInner: Module { let vocabSize: Int let hiddenSize: Int - fileprivate let layers: [DecoderLayer] + let _mupVector: MLXArray + let layers: [DecoderLayer] @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding @ModuleInfo(key: "final_layernorm") var finalLayerNorm: RMSNorm @@ -624,10 +641,9 @@ private class ModelInner: Module { self.vocabSize = args.vocabSize self.hiddenSize = args.hiddenSize - precondition(vocabSize > 0) - _embedTokens.wrappedValue = Embedding(embeddingCount: vocabSize, dimensions: hiddenSize) + self._mupVector = computeMupVector(args) self.layers = (0 ..< args.numHiddenLayers).map { _ in DecoderLayer(args) } @@ -643,7 +659,8 @@ private class ModelInner: Module { let cache: [CacheList?] = cache ?? Array(repeating: nil, count: layers.count) let mambaMask = createSSMMask(h: h, cache: cache[0]?[0] as? MambaCache) - let attnMask: MLXArray? = createAttentionMask(h: h, cache: cache[0]?[1] != nil ? [cache[0]![1]] : nil) + let attnMask: MLXArray? = createAttentionMask( + h: h, cache: cache[0]?[1] != nil ? [cache[0]![1]] : nil) for (layer, c) in zip(layers, cache) { h = layer( @@ -665,7 +682,7 @@ public class FalconH1Model: Module, LLMModel, KVCacheDimensionProvider { private let model: ModelInner let configuration: FalconH1Configuration - @ModuleInfo(key: "lm_head") var lmHead: Linear? + @ModuleInfo(key: "lm_head") var lmHead: Linear public init(_ args: FalconH1Configuration) { self.configuration = args @@ -673,20 +690,18 @@ public class FalconH1Model: Module, LLMModel, KVCacheDimensionProvider { self.kvHeads = (0 ..< args.numKeyValueHeads).map { _ in args.numHiddenLayers } self.model = ModelInner(args) - if !args.tieWordEmbeddings { - _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabSize, bias: false) - } + _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabSize, bias: false) } public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray { - var out = model(inputs, cache: cache as? [CacheList]) - if let lmHead { - out = lmHead(out) - } else { - out = model.embedTokens.asLinear(out) - } + let out = model(inputs, cache: cache as? 
[CacheList]) + return lmHead(out) + } - return out + public func makeCache() -> [CacheList] { + return (0 ..< configuration.numHiddenLayers).map { _ in + CacheList(MambaCache(), KVCacheSimple()) + } } public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { @@ -716,8 +731,9 @@ public class FalconH1Model: Module, LLMModel, KVCacheDimensionProvider { } else if name.hasSuffix("down_proj.weight") { param = param * args.mlpMultipliers[1] } else if name.hasSuffix("in_proj.weight") { - let mupVector = computeMupVector(args) - param = param * (args.ssmInMultiplier * mupVector.asType(param.dtype)[0..., .newAxis]) + param = + param + * (args.ssmInMultiplier * model._mupVector.asType(param.dtype)[0..., .newAxis]) } else if name.contains("conv1d.weight") { param = param.transposed(0, 2, 1) } @@ -740,4 +756,3 @@ extension FalconH1Model: LoRAModel { model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } } } - diff --git a/Libraries/MLXLLM/Models/SSM.swift b/Libraries/MLXLLM/Models/SSM.swift index 0111a64b..f5ed5ba0 100644 --- a/Libraries/MLXLLM/Models/SSM.swift +++ b/Libraries/MLXLLM/Models/SSM.swift @@ -76,7 +76,6 @@ private final class SSMKernelManager: @unchecked Sendable { } } - func ssmUpdateKernel( hiddenStates: MLXArray, ALog: MLXArray, @@ -116,7 +115,6 @@ func ssmUpdateKernel( return (outputs[0], outputs[1]) } - public func segsum(_ x: MLXArray, mask: MLXArray? = nil) -> MLXArray { let l = x.dim(-1) var x = x @@ -141,7 +139,6 @@ public func segsum(_ x: MLXArray, mask: MLXArray? = nil) -> MLXArray { return xSegsum } - public func ssmAttn( x: MLXArray, ALog: MLXArray, @@ -209,7 +206,6 @@ public func ssmUpdate( ) -> (MLXArray, MLXArray) { let seqLen = hiddenStates.dim(1) - if seqLen == 1, let state = state, SSMKernelManager.shared.ssmKernel != nil