|  | 
| 52 | 52 |    "outputs": [], | 
| 53 | 53 |    "source": [ | 
| 54 | 54 |     "%%writefile handler/code/requirements.txt\n", | 
| 55 |  | -    "sentence-transformers==5.0.0" | 
|  | 55 | +    "transformers==4.56.1\n", | 
|  | 56 | +    "huggingface_hub==0.35.0\n", | 
|  | 57 | +    "hf_xet==1.1.10\n", | 
|  | 58 | +    "tokenizers==0.22.0\n", | 
|  | 59 | +    "regex==2025.9.1\n", | 
|  | 60 | +    "safetensors==0.6.2\n", | 
|  | 61 | +    "sentence-transformers==5.1.0" | 
| 56 | 62 |    ] | 
| 57 | 63 |   }, | 
| 58 | 64 |   { | 
|  | 
| 134 | 140 |     "        )\n", | 
| 135 | 141 |     "        print(f\"Using device: {self.device}\")\n", | 
| 136 | 142 |     "        self.model = SparseEncoder(model_id, device=self.device, trust_remote_code=trust_remote_code)\n", | 
|  | 143 | +    "        self._warmup()\n", | 
| 137 | 144 |     "        self.initialized = True\n", | 
| 138 | 145 |     "\n", | 
| 139 |  | -    "    def preprocess(self, requests):\n", | 
|  | 146 | +    "    def _warmup(self):\n", | 
|  | 147 | +    "        input_data = [{\"body\": [\"hello world\"] * 10}]\n", | 
|  | 148 | +    "        self.handle(input_data, None)\n", | 
|  | 149 | +    "\n", | 
|  | 150 | +    "    def _preprocess(self, requests):\n", | 
| 140 | 151 |     "        inputSentence = []\n", | 
| 141 | 152 |     "        batch_idx = []\n", | 
|  | 153 | +    "        formats = []  # per-text format: \"word\" or \"token_id\"\n", | 
| 142 | 154 |     "\n", | 
| 143 | 155 |     "        for request in requests:\n", | 
| 144 | 156 |     "            request_body = request.get(\"body\")\n", | 
| 145 | 157 |     "            if isinstance(request_body, bytearray):\n", | 
| 146 | 158 |     "                request_body = request_body.decode(\"utf-8\")\n", | 
| 147 | 159 |     "                request_body = json.loads((request_body))\n", | 
| 148 |  | -    "            if isinstance(request_body, list):\n", | 
|  | 160 | +    "\n", | 
|  | 161 | +    "            # dict-based new schema: {\"texts\": str | list[str], \"sparse_embedding_format\": str}\n", | 
|  | 162 | +    "            if isinstance(request_body, dict):\n", | 
|  | 163 | +    "                texts = request_body.get(\"texts\")\n", | 
|  | 164 | +    "                fmt = request_body.get(\"sparse_embedding_format\", \"word\")\n", | 
|  | 165 | +    "                fmt = \"token_id\" if isinstance(fmt, str) and fmt.lower() == \"token_id\" else \"word\"\n", | 
|  | 166 | +    "\n", | 
|  | 167 | +    "                if isinstance(texts, list):\n", | 
|  | 168 | +    "                    inputSentence += texts\n", | 
|  | 169 | +    "                    batch_idx.append(len(texts))\n", | 
|  | 170 | +    "                    formats += [fmt] * len(texts)\n", | 
|  | 171 | +    "                else:\n", | 
|  | 172 | +    "                    inputSentence.append(texts)\n", | 
|  | 173 | +    "                    batch_idx.append(1)\n", | 
|  | 174 | +    "                    formats.append(fmt)\n", | 
|  | 175 | +    "\n", | 
|  | 176 | +    "            # legacy schemas\n", | 
|  | 177 | +    "            elif isinstance(request_body, list):\n", | 
| 149 | 178 |     "                inputSentence += request_body\n", | 
| 150 | 179 |     "                batch_idx.append(len(request_body))\n", | 
|  | 180 | +    "                formats += [\"word\"] * len(request_body)\n", | 
| 151 | 181 |     "            else:\n", | 
| 152 | 182 |     "                inputSentence.append(request_body)\n", | 
| 153 | 183 |     "                batch_idx.append(1)\n", | 
|  | 184 | +    "                formats.append(\"word\")\n", | 
|  | 185 | +    "\n", | 
|  | 186 | +    "        return inputSentence, batch_idx, formats\n", | 
| 154 | 187 |     "\n", | 
| 155 |  | -    "        return inputSentence, batch_idx\n", | 
|  | 188 | +    "    def _convert_token_ids(self, sparse_embedding):\n", | 
|  | 189 | +    "        token_ids = self.model.tokenizer.convert_tokens_to_ids([x[0] for x in sparse_embedding])\n", | 
|  | 190 | +    "        return [(str(token_ids[i]), sparse_embedding[i][1]) for i in range(len(token_ids))]\n", | 
| 156 | 191 |     "\n", | 
| 157 | 192 |     "    def handle(self, data, context):\n", | 
| 158 |  | -    "        inputSentence, batch_idx = self.preprocess(data)\n", | 
|  | 193 | +    "        inputSentence, batch_idx, formats = self._preprocess(data)\n", | 
| 159 | 194 |     "        model_output = self.model.encode_document(inputSentence, batch_size=max_bs)\n", | 
| 160 |  | -    "        sparse_embedding = list(map(dict,self.model.decode(model_output)))\n", | 
|  | 195 | +    "\n", | 
|  | 196 | +    "        sparse_embedding_word = self.model.decode(model_output)\n", | 
|  | 197 | +    "        for i, fmt in enumerate(formats):\n", | 
|  | 198 | +    "            if fmt == \"token_id\":\n", | 
|  | 199 | +    "                sparse_embedding_word[i] = self._convert_token_ids(sparse_embedding_word[i])\n", | 
|  | 200 | +    "        sparse_embedding = list(map(dict, sparse_embedding_word))\n", | 
| 161 | 201 |     "\n", | 
| 162 | 202 |     "        outputs = [sparse_embedding[s:e]\n", | 
| 163 | 203 |     "           for s, e in zip([0]+list(itertools.accumulate(batch_idx))[:-1],\n", | 
|  | 
| 424 | 464 |     "```json\n", | 
| 425 | 465 |     "POST /_plugins/_ml/connectors/_create\n", | 
| 426 | 466 |     "{\n", | 
| 427 |  | -    "  \"name\": \"test\",\n", | 
| 428 |  | -    "  \"description\": \"Test connector for Sagemaker model\",\n", | 
|  | 467 | +    "  \"name\": \"Sagemaker Connector: embedding\",\n", | 
|  | 468 | +    "  \"description\": \"The connector to sagemaker embedding model\",\n", | 
| 429 | 469 |     "  \"version\": 1,\n", | 
| 430 | 470 |     "  \"protocol\": \"aws_sigv4\",\n", | 
| 431 | 471 |     "  \"credential\": {\n", | 
|  | 
| 436 | 476 |     "    \"region\": \"{region}\",\n", | 
| 437 | 477 |     "    \"service_name\": \"sagemaker\",\n", | 
| 438 | 478 |     "    \"input_docs_processed_step_size\": 2,\n", | 
|  | 479 | +    "    \"sparse_embedding_format\": \"word\"\n", | 
| 439 | 480 |     "  },\n", | 
| 440 | 481 |     "  \"actions\": [\n", | 
| 441 | 482 |     "    {\n", | 
|  | 
| 445 | 486 |     "        \"content-type\": \"application/json\"\n", | 
| 446 | 487 |     "      },\n", | 
| 447 | 488 |     "      \"url\": \"https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{predictor.endpoint_name}/invocations\",\n", | 
| 448 |  | -    "      \"request_body\": \"${parameters.input}\"\n", | 
|  | 489 | +    "      \"request_body\": \"\"\"\n", | 
|  | 490 | +    "          {\n", | 
|  | 491 | +    "              \"texts\": ${parameters.input},\n", | 
|  | 492 | +    "              \"sparse_embedding_format\": \"${parameters.sparse_embedding_format}\"\n", | 
|  | 493 | +    "          }\n", | 
|  | 494 | +    "      \"\"\"\n", | 
| 449 | 495 |     "    }\n", | 
| 450 | 496 |     "  ],\n", | 
| 451 | 497 |     "  \"client_config\":{\n", | 
|  | 
0 commit comments