|
52 | 52 | "outputs": [], |
53 | 53 | "source": [ |
54 | 54 | "%%writefile handler/code/requirements.txt\n", |
55 | | - "sentence-transformers==5.0.0" |
| 55 | + "transformers==4.56.1\n", |
| 56 | + "huggingface_hub==0.35.0\n", |
| 57 | + "hf_xet==1.1.10\n", |
| 58 | + "tokenizers==0.22.0\n", |
| 59 | + "regex==2025.9.1\n", |
| 60 | + "safetensors==0.6.2\n", |
| 61 | + "sentence-transformers==5.1.0" |
56 | 62 | ] |
57 | 63 | }, |
58 | 64 | { |
|
134 | 140 | " )\n", |
135 | 141 | " print(f\"Using device: {self.device}\")\n", |
136 | 142 | " self.model = SparseEncoder(model_id, device=self.device, trust_remote_code=trust_remote_code)\n", |
| 143 | + " self._warmup()\n", |
137 | 144 | " self.initialized = True\n", |
138 | 145 | "\n", |
139 | | - " def preprocess(self, requests):\n", |
| 146 | + " def _warmup(self):\n", |
| 147 | + " input_data = [{\"body\": [\"hello world\"] * 10}]\n", |
| 148 | + " self.handle(input_data, None)\n", |
| 149 | + "\n", |
| 150 | + " def _preprocess(self, requests):\n", |
140 | 151 | " inputSentence = []\n", |
141 | 152 | " batch_idx = []\n", |
| 153 | + " formats = [] # per-text format: \"word\" or \"token_id\"\n", |
142 | 154 | "\n", |
143 | 155 | " for request in requests:\n", |
144 | 156 | " request_body = request.get(\"body\")\n", |
145 | 157 | " if isinstance(request_body, bytearray):\n", |
146 | 158 | " request_body = request_body.decode(\"utf-8\")\n", |
147 | 159 | " request_body = json.loads((request_body))\n", |
148 | | - " if isinstance(request_body, list):\n", |
| 160 | + "\n", |
| 161 | + " # dict-based new schema: {\"texts\": str | list[str], \"sparse_embedding_format\": str}\n", |
| 162 | + " if isinstance(request_body, dict):\n", |
| 163 | + " texts = request_body.get(\"texts\")\n", |
| 164 | + " fmt = request_body.get(\"sparse_embedding_format\", \"word\")\n", |
| 165 | + " fmt = \"token_id\" if isinstance(fmt, str) and fmt.lower() == \"token_id\" else \"word\"\n", |
| 166 | + "\n", |
| 167 | + " if isinstance(texts, list):\n", |
| 168 | + " inputSentence += texts\n", |
| 169 | + " batch_idx.append(len(texts))\n", |
| 170 | + " formats += [fmt] * len(texts)\n", |
| 171 | + " else:\n", |
| 172 | + " inputSentence.append(texts)\n", |
| 173 | + " batch_idx.append(1)\n", |
| 174 | + " formats.append(fmt)\n", |
| 175 | + "\n", |
| 176 | + " # legacy schemas\n", |
| 177 | + " elif isinstance(request_body, list):\n", |
149 | 178 | " inputSentence += request_body\n", |
150 | 179 | " batch_idx.append(len(request_body))\n", |
| 180 | + " formats += [\"word\"] * len(request_body)\n", |
151 | 181 | " else:\n", |
152 | 182 | " inputSentence.append(request_body)\n", |
153 | 183 | " batch_idx.append(1)\n", |
| 184 | + " formats.append(\"word\")\n", |
| 185 | + "\n", |
| 186 | + " return inputSentence, batch_idx, formats\n", |
154 | 187 | "\n", |
155 | | - " return inputSentence, batch_idx\n", |
| 188 | + " def _convert_token_ids(self, sparse_embedding):\n", |
| 189 | + " token_ids = self.model.tokenizer.convert_tokens_to_ids([x[0] for x in sparse_embedding])\n", |
| 190 | + " return [(str(token_ids[i]), sparse_embedding[i][1]) for i in range(len(token_ids))]\n", |
156 | 191 | "\n", |
157 | 192 | " def handle(self, data, context):\n", |
158 | | - " inputSentence, batch_idx = self.preprocess(data)\n", |
| 193 | + " inputSentence, batch_idx, formats = self._preprocess(data)\n", |
159 | 194 | " model_output = self.model.encode_document(inputSentence, batch_size=max_bs)\n", |
160 | | - " sparse_embedding = list(map(dict,self.model.decode(model_output)))\n", |
| 195 | + "\n", |
| 196 | + " sparse_embedding_word = self.model.decode(model_output)\n", |
| 197 | + " for i, fmt in enumerate(formats):\n", |
| 198 | + " if fmt == \"token_id\":\n", |
| 199 | + " sparse_embedding_word[i] = self._convert_token_ids(sparse_embedding_word[i])\n", |
| 200 | + " sparse_embedding = list(map(dict, sparse_embedding_word))\n", |
161 | 201 | "\n", |
162 | 202 | " outputs = [sparse_embedding[s:e]\n", |
163 | 203 | " for s, e in zip([0]+list(itertools.accumulate(batch_idx))[:-1],\n", |
|
424 | 464 | "```json\n", |
425 | 465 | "POST /_plugins/_ml/connectors/_create\n", |
426 | 466 | "{\n", |
427 | | - " \"name\": \"test\",\n", |
428 | | - " \"description\": \"Test connector for Sagemaker model\",\n", |
| 467 | + " \"name\": \"Sagemaker Connector: embedding\",\n", |
| 468 | + " \"description\": \"The connector to sagemaker embedding model\",\n", |
429 | 469 | " \"version\": 1,\n", |
430 | 470 | " \"protocol\": \"aws_sigv4\",\n", |
431 | 471 | " \"credential\": {\n", |
|
436 | 476 | " \"region\": \"{region}\",\n", |
437 | 477 | " \"service_name\": \"sagemaker\",\n", |
438 | 478 | " \"input_docs_processed_step_size\": 2,\n", |
| 479 | + " \"sparse_embedding_format\": \"word\"\n", |
439 | 480 | " },\n", |
440 | 481 | " \"actions\": [\n", |
441 | 482 | " {\n", |
|
445 | 486 | " \"content-type\": \"application/json\"\n", |
446 | 487 | " },\n", |
447 | 488 | " \"url\": \"https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{predictor.endpoint_name}/invocations\",\n", |
448 | | - " \"request_body\": \"${parameters.input}\"\n", |
| 489 | + " \"request_body\": \"\"\"\n", |
| 490 | + " {\n", |
| 491 | + " \"texts\": ${parameters.input},\n", |
| 492 | + " \"sparse_embedding_format\": \"${parameters.sparse_embedding_format}\"\n", |
| 493 | + " }\n", |
| 494 | + " \"\"\"\n", |
449 | 495 | " }\n", |
450 | 496 | " ],\n", |
451 | 497 | " \"client_config\":{\n", |
|
0 commit comments