diff --git a/Libraries/MLXLLM/LLMModelFactory.swift b/Libraries/MLXLLM/LLMModelFactory.swift index ea8aedba..dd158557 100644 --- a/Libraries/MLXLLM/LLMModelFactory.swift +++ b/Libraries/MLXLLM/LLMModelFactory.swift @@ -50,6 +50,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable { "mimo": create(MiMoConfiguration.self, MiMoModel.init), "glm4": create(GLM4Configuration.self, GLM4Model.init), "acereason": create(Qwen2Configuration.self, Qwen2Model.init), + "falcon_h1": create(FalconH1Configuration.self, FalconH1Model.init), "bitnet": create(BitnetConfiguration.self, BitnetModel.init), "smollm3": create(SmolLM3Configuration.self, SmolLM3Model.init), "ernie4_5": create(Ernie45Configuration.self, Ernie45Model.init), diff --git a/Libraries/MLXLLM/Models/FalconH1.swift b/Libraries/MLXLLM/Models/FalconH1.swift new file mode 100644 index 00000000..3040960e --- /dev/null +++ b/Libraries/MLXLLM/Models/FalconH1.swift @@ -0,0 +1,758 @@ +// +// FalconH1.swift +// mlx-swift-examples +// +// Created by John Mai on 2025/6/18. +// + +import Foundation +import MLX +import MLXLMCommon +import MLXNN + +// port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/falcon_h1.py + +// MARK: - Configuration + +public struct FalconH1Configuration: Codable, Sendable { + var attentionBias: Bool + var attentionDropout: Float + var attentionInMultiplier: Float + var attentionOutMultiplier: Float + var bosTokenId: Int + var embeddingMultiplier: Float + var eosTokenId: Int + var headDim: Int + var hiddenAct: String + var hiddenSize: Int + var initializerRange: Float + var intermediateSize: Int? + var keyMultiplier: Float + var lmHeadMultiplier: Float + var mambaChunkSize: Int + var mambaConvBias: Bool + var mambaDConv: Int + var mambaDHead: Int + var mambaDSSM: Int + var mambaDState: Int + var mambaExpand: Int + var mambaNGroups: Int + var mambaNHeads: Int + var mambaNormBeforeGate: Bool + var mambaProjBias: Bool + var mambaRMSNorm: Bool + var mambaUseMLP: Bool + var maxPositionEmbeddings: Int + var mlpBias: Bool + var mlpExpansionFactor: Int + var mlpMultipliers: [Float] + var modelType: String + var numAttentionHeads: Int + var numHiddenLayers: Int + var numKeyValueHeads: Int + var numLogitsToKeep: Int + var padTokenId: Int + var projectorsBias: Bool + var rmsNormEps: Float + var ropeTraditional: Bool + var ropeScaling: Float? 
+ var ropeTheta: Float + var ssmInMultiplier: Float + var ssmMultipliers: [Float] + var ssmOutMultiplier: Float + var tieWordEmbeddings: Bool + var torchDtype: String + var vocabSize: Int + + enum CodingKeys: String, CodingKey { + case attentionBias = "attention_bias" + case attentionDropout = "attention_dropout" + case attentionInMultiplier = "attention_in_multiplier" + case attentionOutMultiplier = "attention_out_multiplier" + case bosTokenId = "bos_token_id" + case embeddingMultiplier = "embedding_multiplier" + case eosTokenId = "eos_token_id" + case headDim = "head_dim" + case hiddenAct = "hidden_act" + case hiddenSize = "hidden_size" + case initializerRange = "initializer_range" + case intermediateSize = "intermediate_size" + case keyMultiplier = "key_multiplier" + case lmHeadMultiplier = "lm_head_multiplier" + case mambaChunkSize = "mamba_chunk_size" + case mambaConvBias = "mamba_conv_bias" + case mambaDConv = "mamba_d_conv" + case mambaDHead = "mamba_d_head" + case mambaDSSM = "mamba_d_ssm" + case mambaDState = "mamba_d_state" + case mambaExpand = "mamba_expand" + case mambaNGroups = "mamba_n_groups" + case mambaNHeads = "mamba_n_heads" + case mambaNormBeforeGate = "mamba_norm_before_gate" + case mambaProjBias = "mamba_proj_bias" + case mambaRMSNorm = "mamba_rms_norm" + case mambaUseMLP = "mamba_use_mlp" + case maxPositionEmbeddings = "max_position_embeddings" + case mlpBias = "mlp_bias" + case mlpExpansionFactor = "mlp_expansion_factor" + case mlpMultipliers = "mlp_multipliers" + case modelType = "model_type" + case numAttentionHeads = "num_attention_heads" + case numHiddenLayers = "num_hidden_layers" + case numKeyValueHeads = "num_key_value_heads" + case numLogitsToKeep = "num_logits_to_keep" + case padTokenId = "pad_token_id" + case projectorsBias = "projectors_bias" + case rmsNormEps = "rms_norm_eps" + case ropeTraditional = "rope_traditional" + case ropeScaling = "rope_scaling" + case ropeTheta = "rope_theta" + case ssmInMultiplier = "ssm_in_multiplier" + case ssmMultipliers = "ssm_multipliers" + case ssmOutMultiplier = "ssm_out_multiplier" + case tieWordEmbeddings = "tie_word_embeddings" + case torchDtype = "torch_dtype" + case vocabSize = "vocab_size" + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.attentionBias = + try container.decodeIfPresent(Bool.self, forKey: .attentionBias) ?? false + self.attentionDropout = + try container.decodeIfPresent(Float.self, forKey: .attentionDropout) ?? 0.0 + self.attentionInMultiplier = + try container.decodeIfPresent(Float.self, forKey: .attentionInMultiplier) ?? 1.0 + self.attentionOutMultiplier = + try container.decodeIfPresent(Float.self, forKey: .attentionOutMultiplier) ?? 1.0 + self.bosTokenId = try container.decodeIfPresent(Int.self, forKey: .bosTokenId) ?? 1 + self.embeddingMultiplier = + try container.decodeIfPresent(Float.self, forKey: .embeddingMultiplier) ?? 1.0 + self.eosTokenId = try container.decodeIfPresent(Int.self, forKey: .eosTokenId) ?? 2 + self.headDim = try container.decodeIfPresent(Int.self, forKey: .headDim) ?? 64 + self.hiddenAct = try container.decodeIfPresent(String.self, forKey: .hiddenAct) ?? "silu" + self.hiddenSize = try container.decodeIfPresent(Int.self, forKey: .hiddenSize) ?? 4096 + self.initializerRange = + try container.decodeIfPresent(Float.self, forKey: .initializerRange) ?? 0.02 + self.intermediateSize = + try container.decodeIfPresent(Int.self, forKey: .intermediateSize) ?? 
nil + self.keyMultiplier = + try container.decodeIfPresent(Float.self, forKey: .keyMultiplier) ?? 1.0 + self.lmHeadMultiplier = + try container.decodeIfPresent(Float.self, forKey: .lmHeadMultiplier) ?? 1.0 + self.mambaChunkSize = + try container.decodeIfPresent(Int.self, forKey: .mambaChunkSize) ?? 256 + self.mambaConvBias = + try container.decodeIfPresent(Bool.self, forKey: .mambaConvBias) ?? true + self.mambaDConv = try container.decodeIfPresent(Int.self, forKey: .mambaDConv) ?? 4 + self.mambaDHead = try container.decodeIfPresent(Int.self, forKey: .mambaDHead) ?? 64 + self.mambaDSSM = try container.decodeIfPresent(Int.self, forKey: .mambaDSSM) ?? 1536 + self.mambaDState = try container.decodeIfPresent(Int.self, forKey: .mambaDState) ?? 256 + self.mambaExpand = try container.decodeIfPresent(Int.self, forKey: .mambaExpand) ?? 2 + self.mambaNGroups = try container.decodeIfPresent(Int.self, forKey: .mambaNGroups) ?? 1 + self.mambaNHeads = try container.decodeIfPresent(Int.self, forKey: .mambaNHeads) ?? 128 + self.mambaNormBeforeGate = + try container.decodeIfPresent(Bool.self, forKey: .mambaNormBeforeGate) ?? true + self.mambaProjBias = + try container.decodeIfPresent(Bool.self, forKey: .mambaProjBias) ?? false + self.mambaRMSNorm = try container.decodeIfPresent(Bool.self, forKey: .mambaRMSNorm) ?? false + self.mambaUseMLP = try container.decodeIfPresent(Bool.self, forKey: .mambaUseMLP) ?? true + self.maxPositionEmbeddings = + try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings) ?? 8192 + self.mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) ?? false + self.mlpExpansionFactor = + try container.decodeIfPresent(Int.self, forKey: .mlpExpansionFactor) ?? 8 + self.mlpMultipliers = + try container.decodeIfPresent([Float].self, forKey: .mlpMultipliers) ?? [1.0, 1.0] + self.modelType = + try container.decodeIfPresent(String.self, forKey: .modelType) ?? "falcon_h1" + self.numAttentionHeads = + try container.decodeIfPresent(Int.self, forKey: .numAttentionHeads) ?? 32 + self.numHiddenLayers = + try container.decodeIfPresent(Int.self, forKey: .numHiddenLayers) ?? 32 + self.numKeyValueHeads = + try container.decodeIfPresent(Int.self, forKey: .numKeyValueHeads) ?? 8 + self.numLogitsToKeep = + try container.decodeIfPresent(Int.self, forKey: .numLogitsToKeep) ?? 1 + self.padTokenId = try container.decodeIfPresent(Int.self, forKey: .padTokenId) ?? 0 + self.projectorsBias = + try container.decodeIfPresent(Bool.self, forKey: .projectorsBias) ?? false + self.rmsNormEps = try container.decodeIfPresent(Float.self, forKey: .rmsNormEps) ?? 1e-5 + self.ropeTraditional = + try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) ?? false + self.ropeScaling = try container.decodeIfPresent(Float?.self, forKey: .ropeScaling) ?? nil + self.ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 100000.0 + self.ssmInMultiplier = + try container.decodeIfPresent(Float.self, forKey: .ssmInMultiplier) ?? 1.0 + self.ssmMultipliers = + try container.decodeIfPresent([Float].self, forKey: .ssmMultipliers) ?? [ + 1.0, 1.0, 1.0, 1.0, 1.0, + ] + self.ssmOutMultiplier = + try container.decodeIfPresent(Float.self, forKey: .ssmOutMultiplier) ?? 1.0 + self.tieWordEmbeddings = + try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false + self.torchDtype = + try container.decodeIfPresent(String.self, forKey: .torchDtype) ?? "bfloat16" + self.vocabSize = try container.decodeIfPresent(Int.self, forKey: .vocabSize) ?? 
128000 + } +} + +// MARK: - RMSNormGated + +private class RMSNormGated: Module { + let weight: MLXArray + let varianceEpsilon: Float + let nGroups: Int + let normBeforeGate: Bool + + init(hiddenSize: Int, eps: Float = 1e-6, nGroups: Int = 1, normBeforeGate: Bool = true) { + self.weight = MLXArray.ones([hiddenSize]) + self.varianceEpsilon = eps + self.nGroups = nGroups + self.normBeforeGate = normBeforeGate + } + + func callAsFunction(_ hiddenStates: MLXArray, gate: MLXArray? = nil) -> MLXArray { + var hiddenStates = hiddenStates + + if !normBeforeGate, let gate { + hiddenStates = hiddenStates * silu(gate) + } + + hiddenStates = MLXFast.rmsNorm(hiddenStates, weight: weight, eps: varianceEpsilon) + + if normBeforeGate, let gate { + hiddenStates = hiddenStates * silu(gate) + } + + return hiddenStates + } +} + +private func computeMupVector(_ args: FalconH1Configuration) -> MLXArray { + let intermediateSize = args.mambaDSSM + let groupsTimeStateSize = args.mambaNGroups * args.mambaDState + let numHeads = args.mambaNHeads + + let sizes = [ + intermediateSize, + intermediateSize, + groupsTimeStateSize, + groupsTimeStateSize, + numHeads, + ] + + let segments = zip(sizes, args.ssmMultipliers).map { size, multiplier in + MLX.broadcast(MLXArray(multiplier), to: [size]) + } + + return concatenated(segments) +} + +// MARK: - Attention + +private class Attention: Module { + let hiddenSize: Int + let numHeads: Int + let numKVHeads: Int + let headDim: Int + let scale: Float + + @ModuleInfo(key: "q_proj") var qProj: Linear + @ModuleInfo(key: "k_proj") var kProj: Linear + @ModuleInfo(key: "v_proj") var vProj: Linear + @ModuleInfo(key: "o_proj") var oProj: Linear + + let rope: RoPE + + init(_ args: FalconH1Configuration) { + self.hiddenSize = args.hiddenSize + self.numHeads = args.numAttentionHeads + self.numKVHeads = args.numKeyValueHeads + self.headDim = args.headDim + self.scale = pow(Float(headDim), -0.5) + + _qProj.wrappedValue = Linear(hiddenSize, numHeads * headDim, bias: args.attentionBias) + _kProj.wrappedValue = Linear(hiddenSize, numKVHeads * headDim, bias: args.attentionBias) + _vProj.wrappedValue = Linear(hiddenSize, numKVHeads * headDim, bias: args.attentionBias) + _oProj.wrappedValue = Linear(numHeads * headDim, hiddenSize, bias: args.attentionBias) + + let ropeScale: Float = + if let ropeScaling = args.ropeScaling { + 1 / ropeScaling + } else { + 1 + } + + self.rope = RoPE( + dimensions: headDim, + traditional: args.ropeTraditional, + base: args.ropeTheta, + scale: ropeScale + ) + } + + func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? 
= nil) -> MLXArray { + let (B, L, _) = (x.dim(0), x.dim(1), x.dim(2)) + + var queries = qProj(x) + var keys = kProj(x) + var values = vProj(x) + + queries = queries.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, L, numKVHeads, -1).transposed(0, 2, 1, 3) + values = values.reshaped(B, L, numKVHeads, -1).transposed(0, 2, 1, 3) + + if let cache { + queries = rope(queries, offset: cache.offset) + keys = rope(keys, offset: cache.offset) + (keys, values) = cache.update(keys: keys, values: values) + } else { + queries = rope(queries) + keys = rope(keys) + } + + var output = MLXFast.scaledDotProductAttention( + queries: queries, + keys: keys, + values: values, + scale: scale, + mask: mask + ) + + output = output.transposed(0, 2, 1, 3).reshaped(B, L, -1) + return oProj(output) + } +} + +// MARK: - Mixer + +private class Mixer: Module { + let numHeads: Int + let hiddenSize: Int + let ssmStateSize: Int + let convKernelSize: Int + let intermediateSize: Int + let useConvBias: Bool + let useBias: Bool + let layerNormEpsilon: Float + let groupsTimeStateSize: Int + let nGroups: Int + let headDim: Int + let chunkSize: Int + let timeStepLimit: (Float, Float) + let timeStepMin: Float + let timeStepMax: Float + let convDim: Int + let mambaRMSNorm: Bool + var norm: RMSNormGated? = nil + let ssmInMultiplier: Float + let conv1d: Conv1d + + @ModuleInfo(key: "in_proj") var inProj: Linear + @ParameterInfo(key: "dt_bias") var dtBias: MLXArray + @ParameterInfo(key: "A_log") var aLog: MLXArray + @ParameterInfo(key: "D") var d: MLXArray + @ModuleInfo(key: "out_proj") var outProj: Linear + + init(_ args: FalconH1Configuration) { + self.numHeads = args.mambaNHeads + self.hiddenSize = args.hiddenSize + self.ssmStateSize = args.mambaDState + self.convKernelSize = args.mambaDConv + self.intermediateSize = args.mambaDSSM + self.useConvBias = args.mambaConvBias + self.useBias = args.mambaProjBias + self.layerNormEpsilon = args.rmsNormEps + self.groupsTimeStateSize = args.mambaNGroups * args.mambaDState + self.nGroups = args.mambaNGroups + self.headDim = args.mambaDHead + self.chunkSize = args.mambaChunkSize + self.timeStepLimit = (0.0, Float.infinity) + self.timeStepMin = 0.001 + self.timeStepMax = 0.1 + + self.convDim = intermediateSize + 2 * nGroups * ssmStateSize + + self.conv1d = Conv1d( + inputChannels: convDim, + outputChannels: convDim, + kernelSize: convKernelSize, + groups: convDim, + bias: useConvBias + ) + + let projectionSize = intermediateSize + convDim + numHeads + _inProj.wrappedValue = Linear( + hiddenSize, + projectionSize, + bias: args.mambaProjBias + ) + + _dtBias.wrappedValue = MLXArray.ones([numHeads]) + + let A = MLXArray(Array(1 ..< numHeads + 1)) + + _aLog.wrappedValue = log(A) + + self.mambaRMSNorm = args.mambaRMSNorm + if mambaRMSNorm { + self.norm = RMSNormGated( + hiddenSize: intermediateSize, + eps: layerNormEpsilon, + nGroups: nGroups, + normBeforeGate: args.mambaNormBeforeGate + ) + } + + _d.wrappedValue = MLXArray.ones([numHeads]) + + _outProj.wrappedValue = Linear( + intermediateSize, + hiddenSize, + bias: args.projectorsBias + ) + + self.ssmInMultiplier = args.ssmInMultiplier + } + + private func _applyConv(_ convInput: MLXArray, cache: MambaCache?) -> MLXArray { + let convState: MLXArray + if cache == nil || cache?[0] == nil { + convState = MLXArray.zeros( + [convInput.dim(0), convKernelSize - 1, convDim], + dtype: convInput.dtype + ) + } else { + convState = cache![0]! 
+ } + + let paddedInput = concatenated([convState, convInput], axis: 1) + + if let cache = cache { + cache[0] = paddedInput[0..., (-(convKernelSize - 1))...] + } + + let convOutput = conv1d(paddedInput) + return silu(convOutput) + } + + private func _ssm( + hiddenStates: MLXArray, + B: MLXArray, + C: MLXArray, + dt: MLXArray, + state: MLXArray? = nil, + mask: MLXArray? = nil + ) -> (MLXArray, MLXArray) { + let (batchSize, seqLen, _) = (hiddenStates.dim(0), hiddenStates.dim(1), hiddenStates.dim(2)) + + let hiddenStates = hiddenStates.reshaped(batchSize, seqLen, numHeads, headDim) + let B = B.reshaped(batchSize, seqLen, nGroups, ssmStateSize) + let C = C.reshaped(batchSize, seqLen, nGroups, ssmStateSize) + + let (y, newState) = ssmUpdate( + hiddenStates: hiddenStates, + ALog: aLog, + B: B, + C: C, + D: d, + dt: dt, + dtBias: dtBias, + state: state, + timeStepLimit: timeStepLimit, + mask: mask + ) + + return (y.reshaped(batchSize, seqLen, intermediateSize), newState) + } + + func callAsFunction( + _ inputStates: MLXArray, cache: MambaCache? = nil, mask: MLXArray? = nil + ) -> MLXArray { + let projectedStates = inProj(inputStates) + + let splits = MLX.split( + projectedStates, + indices: [intermediateSize, intermediateSize + convDim], + axis: -1 + ) + let gate = splits[0] + var convInput = splits[1] + let dt = splits[2] + + if let mask = mask { + convInput = which(mask[.ellipsis, .newAxis], convInput, 0) + } + let convOutput = _applyConv(convInput, cache: cache) + + let convSplits = MLX.split( + convOutput, + indices: [ + intermediateSize, + intermediateSize + nGroups * ssmStateSize, + ], + axis: -1 + ) + let hiddenStatesSSM = convSplits[0] + let B = convSplits[1] + let C = convSplits[2] + + var state = cache?[1] + var y: MLXArray + (y, state) = _ssm( + hiddenStates: hiddenStatesSSM, + B: B, + C: C, + dt: dt, + state: state, + mask: mask + ) + if let cache = cache { + cache[1] = state + } + + if let norm = norm { + y = norm(y, gate: gate) + } else { + y = y * silu(gate) + } + + return outProj(y) + } +} + +// MARK: - MLP + +private class MLP: Module, UnaryLayer { + @ModuleInfo(key: "gate_proj") var gateProj: Linear + @ModuleInfo(key: "up_proj") var upProj: Linear + @ModuleInfo(key: "down_proj") var downProj: Linear + + let gateMultiplier: Float + let downMultiplier: Float + + init(_ args: FalconH1Configuration) { + let hiddenSize = args.hiddenSize + let intermediateSize = args.intermediateSize ?? 
4 * hiddenSize + + _gateProj.wrappedValue = Linear(hiddenSize, intermediateSize, bias: args.mlpBias) + _upProj.wrappedValue = Linear(hiddenSize, intermediateSize, bias: args.mlpBias) + _downProj.wrappedValue = Linear(intermediateSize, hiddenSize, bias: args.mlpBias) + + self.gateMultiplier = args.mlpMultipliers[0] + self.downMultiplier = args.mlpMultipliers[1] + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let y = upProj(x) * silu(gateProj(x)) + return downProj(y) + } +} + +// MARK: - DecoderLayer + +private class DecoderLayer: Module { + @ModuleInfo(key: "feed_forward") var feedForward: MLP + @ModuleInfo(key: "mamba") var mamba: Mixer + @ModuleInfo(key: "self_attn") var attention: Attention + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm + @ModuleInfo(key: "pre_ff_layernorm") var preFfLayerNorm: RMSNorm + + let channelsAttn: Int + + init(_ args: FalconH1Configuration) { + let headDim = args.headDim + self.channelsAttn = args.numAttentionHeads * headDim + 2 * args.numKeyValueHeads * headDim + + _feedForward.wrappedValue = MLP(args) + _mamba.wrappedValue = Mixer(args) + _attention.wrappedValue = Attention(args) + _inputLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps + ) + _preFfLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps + ) + } + + func callAsFunction( + _ h: MLXArray, + cache: CacheList?, + attnMask: MLXArray?, + mambaMask: MLXArray? + ) -> MLXArray { + var residual = h + var h = inputLayerNorm(h) + + let mambaH = mamba(h, cache: cache?[0] as? MambaCache, mask: mambaMask) + + let attnH = attention( + h, + mask: attnMask, + cache: cache?[1] + ) + + h = residual + mambaH + attnH + + residual = h + h = preFfLayerNorm(h) + h = feedForward(h) + return residual + h + } +} + +// MARK: - Helper Functions + +private func createSSMMask(h: MLXArray, cache: ArraysCache?) -> MLXArray? { + if let cache = cache { + return cache.makeMask(N: h.dim(1)) + } + return nil +} + +private func createAttentionMask(h: MLXArray, cache: [KVCache]?) -> MLXArray? { + let N = h.dim(1) + // If cache exists and can make masks, use it + // Otherwise for single token, no mask needed + // For multi-token, SDPA will handle causal mask internally when nil + if N == 1 { + return nil + } + return nil // Will be handled by SDPA internally when nil +} + +// MARK: - Model + +private class ModelInner: Module { + let args: FalconH1Configuration + let vocabSize: Int + let hiddenSize: Int + + let _mupVector: MLXArray + let layers: [DecoderLayer] + + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + @ModuleInfo(key: "final_layernorm") var finalLayerNorm: RMSNorm + + init(_ args: FalconH1Configuration) { + self.args = args + self.vocabSize = args.vocabSize + self.hiddenSize = args.hiddenSize + + _embedTokens.wrappedValue = Embedding(embeddingCount: vocabSize, dimensions: hiddenSize) + + self._mupVector = computeMupVector(args) + self.layers = (0 ..< args.numHiddenLayers).map { _ in + DecoderLayer(args) + } + + _finalLayerNorm.wrappedValue = RMSNorm(dimensions: hiddenSize, eps: args.rmsNormEps) + } + + func callAsFunction(_ inputs: MLXArray, mask: MLXArray? = nil, cache: [CacheList]? = nil) + -> MLXArray + { + var h = embedTokens(inputs) + + let cache: [CacheList?] = cache ?? Array(repeating: nil, count: layers.count) + + let mambaMask = createSSMMask(h: h, cache: cache[0]?[0] as? MambaCache) + let attnMask: MLXArray? = createAttentionMask( + h: h, cache: cache[0]?[1] != nil ? 
[cache[0]![1]] : nil) + + for (layer, c) in zip(layers, cache) { + h = layer( + h, + cache: c, + attnMask: attnMask, + mambaMask: mambaMask + ) + } + + return finalLayerNorm(h) + } +} + +public class FalconH1Model: Module, LLMModel, KVCacheDimensionProvider { + public let vocabularySize: Int + public let kvHeads: [Int] + + private let model: ModelInner + let configuration: FalconH1Configuration + + @ModuleInfo(key: "lm_head") var lmHead: Linear + + public init(_ args: FalconH1Configuration) { + self.configuration = args + self.vocabularySize = args.vocabSize + self.kvHeads = (0 ..< args.numKeyValueHeads).map { _ in args.numHiddenLayers } + self.model = ModelInner(args) + + _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabSize, bias: false) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray { + let out = model(inputs, cache: cache as? [CacheList]) + return lmHead(out) + } + + public func makeCache() -> [CacheList] { + return (0 ..< configuration.numHiddenLayers).map { _ in + CacheList(MambaCache(), KVCacheSimple()) + } + } + + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + let c1d = weights["model.layers.0.mamba.conv1d.weight"]! + if c1d.dim(-1) <= c1d.dim(1) { + return weights + } + + var sanitizedWeights = [String: MLXArray]() + let args = configuration + + for (name, var param) in weights { + if name.hasSuffix("embed_tokens.weight") { + param = param * args.embeddingMultiplier + } else if name.hasSuffix("lm_head.weight") { + param = param * args.lmHeadMultiplier + } else if name.hasSuffix("q_proj.weight") || name.hasSuffix("k_proj.weight") { + param = param * args.attentionInMultiplier + } else if name.hasSuffix("key_proj.weight") { + param = param * args.attentionInMultiplier * args.keyMultiplier + } else if name.hasSuffix("o_proj.weight") { + param = param * args.attentionOutMultiplier + } else if name.hasSuffix("out_proj.weight") { + param = param * args.ssmOutMultiplier + } else if name.hasSuffix("gate_proj.weight") { + param = param * args.mlpMultipliers[0] + } else if name.hasSuffix("down_proj.weight") { + param = param * args.mlpMultipliers[1] + } else if name.hasSuffix("in_proj.weight") { + param = + param + * (args.ssmInMultiplier * model._mupVector.asType(param.dtype)[0..., .newAxis]) + } else if name.contains("conv1d.weight") { + param = param.transposed(0, 2, 1) + } + + sanitizedWeights[name] = param + } + + return sanitizedWeights + } + + public func newCache(parameters: GenerateParameters?) -> [any KVCache] { + model.layers.map { _ in CacheList(MambaCache(), KVCacheSimple()) } + } +} + +// MARK: - LoRA + +extension FalconH1Model: LoRAModel { + public func loraLinearLayers() -> LoRALinearLayers { + model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } + } +} diff --git a/Libraries/MLXLLM/Models/SSM.swift b/Libraries/MLXLLM/Models/SSM.swift new file mode 100644 index 00000000..f5ed5ba0 --- /dev/null +++ b/Libraries/MLXLLM/Models/SSM.swift @@ -0,0 +1,238 @@ +// +// SSM.swift +// mlx-swift-examples +// +// Created by John Mai on 2025/10/01. +// + +import Foundation +import MLX +import MLXFast +import MLXNN + +public func computeDt(_ dt: MLXArray, _ dtBias: MLXArray, _ timeStepLimit: (Float, Float)) + -> MLXArray +{ + let dt = softplus(dt + dtBias) + return MLX.clip(dt, min: timeStepLimit.0, max: timeStepLimit.1) +} + +private func makeSSMKernel() -> MLXFast.MLXFastKernel? 
{ + let source = """ + auto n = thread_position_in_grid.z; + auto h_idx = n % H; + auto g_idx = n / G; + constexpr int n_per_t = Ds / 32; + + auto x = X + n * Dh; + out += n * Dh; + auto i_state = state_in + n * Dh * Ds; + auto o_state = state_out + n * Dh * Ds; + + // C and B have shape [batch, group, state_dim] + // C and B need to be offset by group size + auto C_ = C + g_idx * Ds; + auto B_ = B + g_idx * Ds; + + auto ds_idx = thread_position_in_threadgroup.x; + auto d_idx = thread_position_in_grid.y; + + auto dt_ = static_cast(dt[n]); + auto A = -fast::exp(static_cast(A_log[h_idx])); + auto dA = fast::exp(A * dt_); + + float acc = 0.0; + auto x_ = static_cast(x[d_idx]); + + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * ds_idx + i; + auto idx = d_idx * Ds + s_idx; + auto dB_by_x = x_ * dt_ * static_cast(B_[s_idx]); + auto state = dA * i_state[idx] + dB_by_x; + o_state[idx] = static_cast(state); + acc += state * C_[s_idx]; + } + acc = simd_sum(acc); + if (thread_index_in_simdgroup == 0) { + out[d_idx] = static_cast(acc + x_ * D[h_idx]); + } + """ + + return MLXFast.metalKernel( + name: "ssm_kernel", + inputNames: ["X", "A_log", "B", "C", "D", "dt", "state_in"], + outputNames: ["out", "state_out"], + source: source + ) +} + +private final class SSMKernelManager: @unchecked Sendable { + static let shared = SSMKernelManager() + + let ssmKernel: MLXFast.MLXFastKernel? + + private init() { + ssmKernel = makeSSMKernel() + } +} + +func ssmUpdateKernel( + hiddenStates: MLXArray, + ALog: MLXArray, + B: MLXArray, + C: MLXArray, + D: MLXArray, + dt: MLXArray, + dtBias: MLXArray, + state: MLXArray, + timeStepLimit: (Float, Float) +) -> (MLXArray, MLXArray) { + let (n, _, h, d) = hiddenStates.shape4 + let inputType = hiddenStates.dtype + let (hb, ds) = (B.dim(-2), B.dim(-1)) + + let dt = computeDt(dt, dtBias, timeStepLimit) + + guard let kernel = SSMKernelManager.shared.ssmKernel else { + fatalError("SSM kernel not available") + } + + let outputs = kernel( + [hiddenStates, ALog, B, C, D, dt, state], + template: [ + ("T", inputType), + ("Dh", d), + ("Ds", ds), + ("H", h), + ("G", h / hb), + ], + grid: (32, d, h * n), + threadGroup: (32, 8, 1), + outputShapes: [[n, 1, h, d], state.shape], + outputDTypes: [inputType, inputType] + ) + + return (outputs[0], outputs[1]) +} + +public func segsum(_ x: MLXArray, mask: MLXArray? = nil) -> MLXArray { + let l = x.dim(-1) + var x = x + + if let mask = mask { + let mask = MLX.expandedDimensions(mask, axis: 1) + x = x * mask + } + + x = MLX.repeated(x[.ellipsis, .newAxis], count: l, axis: -1) + x = MLX.tril(x, k: -1) + var xSegsum = MLX.cumsum(x, axis: -2) + + if let mask = mask { + xSegsum = which( + mask[.ellipsis, .newAxis, 0...] * mask[.ellipsis, .newAxis], + xSegsum, + MLXArray(-Float.infinity) + ) + } + + return xSegsum +} + +public func ssmAttn( + x: MLXArray, + ALog: MLXArray, + B: MLXArray, + C: MLXArray, + D: MLXArray, + dt: MLXArray, + dtBias: MLXArray, + state: MLXArray? = nil, + timeStepLimit: (Float, Float) = (0.001, 100.0), + mask: MLXArray? 
= nil +) -> (MLXArray, MLXArray) { + let (b, l, h, dh) = x.shape4 + let (_, _, g, d) = B.shape4 + + let dt = computeDt(dt, dtBias, timeStepLimit) + let repeats = h / g + let A = -MLX.exp(ALog) + var B = MLX.transposed(B, axes: [0, 2, 3, 1]) + + // A * s + B * C + var CB = MLX.swappedAxes(C, 1, 2).matmul(B) + CB = MLX.repeated(CB, count: repeats, axis: 1) + + let dtA = dt * A.reshaped(1, 1, -1) + var decay = MLX.exp(segsum(dtA.swappedAxes(1, 2), mask: mask)) + + let surrogateAttentionMatrix = MLX.tril(CB * decay, k: 0) + + let dtx = dt.reshaped(b, l, h, 1) * x + var y = surrogateAttentionMatrix.matmul(dtx.swappedAxes(1, 2)) + y = MLX.swappedAxes(y, 1, 2) + + decay = decay[0..., 0..., (-1)..., 0...].transposed(0, 3, 1, 2) + B = MLX.repeated(B, count: h / g, axis: 1).swappedAxes(2, 3) + var dtxdecay = dtx * decay + dtxdecay = dtxdecay.swappedAxes(1, 2).swappedAxes(2, 3) + + var nextState = dtxdecay.matmul(B) + + if var state = state { + let expDtACumsum = MLX.exp(MLX.cumsum(dtA, axis: -2)) + nextState = nextState + expDtACumsum[0..., -1, 0..., .newAxis, .newAxis] * state + state = state.reshaped(b, 1, g, repeats, dh, d) + let C = C.reshaped(b, l, g, 1, d, 1) + let yPrev = (state.matmul(C)).squeezed(axis: -1).flattened(start: 2, end: 3) + y = y + expDtACumsum[.ellipsis, .newAxis] * yPrev + } + + y = y + x * D.reshaped(1, 1, h, 1) + return (y, nextState) +} + +public func ssmUpdate( + hiddenStates: MLXArray, + ALog: MLXArray, + B: MLXArray, + C: MLXArray, + D: MLXArray, + dt: MLXArray, + dtBias: MLXArray, + state: MLXArray? = nil, + timeStepLimit: (Float, Float) = (0.001, 100.0), + mask: MLXArray? = nil +) -> (MLXArray, MLXArray) { + let seqLen = hiddenStates.dim(1) + + if seqLen == 1, + let state = state, + SSMKernelManager.shared.ssmKernel != nil + { + return ssmUpdateKernel( + hiddenStates: hiddenStates, + ALog: ALog, + B: B, + C: C, + D: D, + dt: dt, + dtBias: dtBias, + state: state, + timeStepLimit: timeStepLimit + ) + } else { + return ssmAttn( + x: hiddenStates, + ALog: ALog, + B: B, + C: C, + D: D, + dt: dt, + dtBias: dtBias, + state: state, + timeStepLimit: timeStepLimit, + mask: mask + ) + } +} diff --git a/Libraries/MLXLMCommon/BaseConfiguration.swift b/Libraries/MLXLMCommon/BaseConfiguration.swift index e9c0ed18..83083f3d 100644 --- a/Libraries/MLXLMCommon/BaseConfiguration.swift +++ b/Libraries/MLXLMCommon/BaseConfiguration.swift @@ -21,6 +21,7 @@ public struct BaseConfiguration: Codable, Sendable { public var quantMethod: String? = nil public var linearClass: String? = nil public var quantizationMode: String? = nil + public var mode: String? = nil public var asTuple: (Int, Int) { (groupSize, bits) } @@ -30,6 +31,7 @@ public struct BaseConfiguration: Codable, Sendable { case quantMethod = "quant_method" case linearClass = "linear_class" case quantizationMode = "quantization_mode" + case mode = "mode" } } @@ -116,6 +118,7 @@ public struct BaseConfiguration: Codable, Sendable { case Quantization.CodingKeys.quantMethod.rawValue: continue case Quantization.CodingKeys.linearClass.rawValue: continue case Quantization.CodingKeys.quantizationMode.rawValue: continue + case Quantization.CodingKeys.mode.rawValue: continue default: if let f = try? 
container.decode(Bool.self, forKey: key) { diff --git a/Libraries/MLXLMCommon/KVCache.swift b/Libraries/MLXLMCommon/KVCache.swift index 75559c8f..d3fe3eb8 100644 --- a/Libraries/MLXLMCommon/KVCache.swift +++ b/Libraries/MLXLMCommon/KVCache.swift @@ -206,6 +206,13 @@ public func createAttentionMask(h: MLXArray, cache: [KVCache]?, returnArray: Boo return .none } +public func createSSMMask(h: MLXArray, cache: MambaCache?) -> MLXArray? { + if let cache { + return cache.makeMask(N: h.dim(1)) + } + return nil +} + /// Standard KV cache implementation based on Python's KVCache /// See https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/base.py#L11 public class KVCacheSimple: BaseKVCache, CustomDebugStringConvertible { @@ -887,11 +894,14 @@ public class ChunkedKVCache: KVCacheSimple { } } -/// Simple cache for Mamba-style state space models -public class MambaCache: BaseKVCache { - private var cache: [MLXArray?] = [nil, nil] +/// Base cache for array-based state storage +public class ArraysCache: BaseKVCache { + private var cache: [MLXArray?] + private var leftPadding: MLXArray? - public override init() { + public init(size: Int, leftPadding: [Int]? = nil) { + self.cache = Array(repeating: nil, count: size) + self.leftPadding = leftPadding.map { MLXArray($0) } super.init() } @@ -904,39 +914,48 @@ public class MambaCache: BaseKVCache { set { cache[index] = newValue } } - public override func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) { - // Mamba doesn't use traditional KV cache update pattern - fatalError("MambaCache should not use update(keys:values:) - use subscript access instead") - } - public override var state: [MLXArray] { get { - // Need to preserve the structure including nils, similar to Python version - // Use empty arrays as placeholders for nil values - var result: [MLXArray] = [] - for item in cache { - if let array = item { - result.append(array) - } else { - // Use an empty array as placeholder for nil (this shape should never occur naturally) - result.append(MLXArray.zeros([0], dtype: .float32)) - } - } - return result + return cache.compactMap { $0 } } set { - guard newValue.count == cache.count else { - fatalError("MambaCache state must have exactly \(cache.count) elements") - } - for (i, array) in newValue.enumerated() { - // Check if this is our nil placeholder (empty array with size 0) - if array.size == 0 { - cache[i] = nil - } else { - cache[i] = array - } + cache = newValue.map { $0 as MLXArray? } + } + } + + /// In-place filter to keep just the given indices in the cache + public func filter(batchIndices: MLXArray) { + cache = cache.map { c in + c?[batchIndices] + } + leftPadding = nil + } + + /// In-place extend this cache with the other cache + public func extend(other: ArraysCache) { + cache = zip(cache, other.cache).map { (c, o) in + if let c = c, let o = o { + return MLX.concatenated([c, o]) } + return c ?? o } + leftPadding = nil + } + + /// Create attention mask based on left padding + public func makeMask(N: Int) -> MLXArray? { + if cache[0] == nil, let leftPadding = leftPadding { + return MLXArray(0 ..< N) .>= leftPadding[0..., .newAxis] + } else { + return nil + } + } +} + +/// Simple cache for Mamba-style state space models +public class MambaCache: ArraysCache { + public init(leftPadding: [Int]? = nil) { + super.init(size: 2, leftPadding: leftPadding) } }
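
Note (not part of the patch): because every field of FalconH1Configuration is decoded with decodeIfPresent plus a fallback, a minimal config.json is enough to construct the model, and the "falcon_h1" entry added to LLMTypeRegistry routes that model_type to this initializer. A construction-only sketch, assuming just the public types introduced above (FalconH1Configuration, FalconH1Model, CacheList, MambaCache, KVCacheSimple); real use would load weights through the factory instead of building an unweighted model:

import Foundation
import MLXLLM
import MLXLMCommon

// Minimal config: any key not listed falls back to the defaults in init(from:).
// The mamba_* values are chosen so mamba_n_heads * mamba_d_head == mamba_d_ssm.
let json = """
{
  "model_type": "falcon_h1",
  "hidden_size": 1024,
  "num_hidden_layers": 4,
  "num_attention_heads": 16,
  "num_key_value_heads": 4,
  "mamba_n_heads": 24,
  "mamba_d_head": 64,
  "mamba_d_ssm": 1536,
  "vocab_size": 32000
}
""".data(using: .utf8)!

let config = try! JSONDecoder().decode(FalconH1Configuration.self, from: json)
let model = FalconH1Model(config)

// One CacheList per decoder layer: slot 0 holds the Mamba conv/SSM state,
// slot 1 is an ordinary attention KV cache.
let cache = model.makeCache()
assert(cache.count == config.numHiddenLayers)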
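For reference, the single-token fast path in ssmUpdateKernel implements the standard selective state-space recurrence; the multi-token prefill path (ssmAttn) computes the same recurrence in the chunked "surrogate attention" form via segsum. The scalar sketch below is illustrative only (the helper name and plain-Swift types are not part of the patch) and mirrors what the Metal kernel computes per (head, head-channel) thread, with dt assumed to have already passed through computeDt (softplus plus clipping):

import Foundation

/// Scalar reference for one head: state' = exp(A*dt) * state + (dt * B) * x,
/// y = C · state' + D * x, with A = -exp(A_log), matching the kernel source.
func ssmStepReference(
    x: [Double],              // inputs for this head, length Dh (mamba_d_head)
    aLog: Double,             // A_log entry for this head
    b: [Double],              // B for this head's group, length Ds (mamba_d_state)
    c: [Double],              // C for this head's group, length Ds
    d: Double,                // D skip weight for this head
    dt: Double,               // discretization step (already softplus'd and clipped)
    state: inout [[Double]]   // running SSM state, shape [Dh][Ds]
) -> [Double] {
    let dA = exp(-exp(aLog) * dt)        // decay factor exp(A * dt)
    var y = [Double](repeating: 0, count: x.count)
    for dIdx in x.indices {
        var acc = 0.0
        for sIdx in b.indices {
            // update the state, then accumulate its projection through C
            let s = dA * state[dIdx][sIdx] + dt * b[sIdx] * x[dIdx]
            state[dIdx][sIdx] = s
            acc += s * c[sIdx]
        }
        y[dIdx] = acc + d * x[dIdx]      // add the D (skip) term
    }
    return y
}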
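The left-padding support added to ArraysCache is what feeds createSSMMask: on the first prefill, before any conv state has been written, makeMask produces a per-batch boolean mask that the Mixer uses to zero out padded positions ahead of the causal convolution; afterwards it returns nil. A small sketch using only the MambaCache initializer added in this diff (the padding values are illustrative):

import MLX
import MLXLMCommon

// Batch of two sequences; the first is left-padded by 3 tokens.
let mambaCache = MambaCache(leftPadding: [3, 0])

// Expected [batch, N] mask for a 5-token prefill:
//   row 0 -> [false, false, false, true, true]
//   row 1 -> [true,  true,  true,  true, true]
if let mask = mambaCache.makeMask(N: 5) {
    print(mask)
}

// Once the first update populates cache[0] (the conv state), makeMask
// returns nil and no masking is applied on later decode steps.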