Commit 5139cff

Merge branch 'main' into fix/claude-v3-bedrock-support

2 parents: ef006b9 + f5510c9
File tree: 1 file changed (+55 −9)

docs/model_serving_framework/deploy_sparse_model_to_SageMaker.ipynb

Lines changed: 55 additions & 9 deletions
@@ -52,7 +52,13 @@
    "outputs": [],
    "source": [
     "%%writefile handler/code/requirements.txt\n",
-    "sentence-transformers==5.0.0"
+    "transformers==4.56.1\n",
+    "huggingface_hub==0.35.0\n",
+    "hf_xet==1.1.10\n",
+    "tokenizers==0.22.0\n",
+    "regex==2025.9.1\n",
+    "safetensors==0.6.2\n",
+    "sentence-transformers==5.1.0"
    ]
   },
   {
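
The requirements now pin the full Hugging Face stack (transformers, tokenizers, safetensors, and friends) explicitly instead of relying on sentence-transformers to resolve them. A minimal sketch, assuming the packages above are installed in the handler's environment, to confirm the resolved versions match the pins before packaging the model artifact:

```python
# Sketch: report the installed versions of the packages pinned in
# handler/code/requirements.txt so they can be compared against the pins.
from importlib.metadata import PackageNotFoundError, version

pinned = [
    "transformers", "huggingface_hub", "hf_xet", "tokenizers",
    "regex", "safetensors", "sentence-transformers",
]
for name in pinned:
    try:
        print(f"{name}=={version(name)}")
    except PackageNotFoundError:
        print(f"{name}: not installed")
```
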
@@ -134,30 +140,64 @@
     "        )\n",
     "        print(f\"Using device: {self.device}\")\n",
     "        self.model = SparseEncoder(model_id, device=self.device, trust_remote_code=trust_remote_code)\n",
+    "        self._warmup()\n",
     "        self.initialized = True\n",
     "\n",
-    "    def preprocess(self, requests):\n",
+    "    def _warmup(self):\n",
+    "        input_data = [{\"body\": [\"hello world\"] * 10}]\n",
+    "        self.handle(input_data, None)\n",
+    "\n",
+    "    def _preprocess(self, requests):\n",
     "        inputSentence = []\n",
     "        batch_idx = []\n",
+    "        formats = [] # per-text format: \"word\" or \"token_id\"\n",
     "\n",
     "        for request in requests:\n",
     "            request_body = request.get(\"body\")\n",
     "            if isinstance(request_body, bytearray):\n",
     "                request_body = request_body.decode(\"utf-8\")\n",
     "            request_body = json.loads((request_body))\n",
-    "            if isinstance(request_body, list):\n",
+    "\n",
+    "            # dict-based new schema: {\"texts\": str | list[str], \"sparse_embedding_format\": str}\n",
+    "            if isinstance(request_body, dict):\n",
+    "                texts = request_body.get(\"texts\")\n",
+    "                fmt = request_body.get(\"sparse_embedding_format\", \"word\")\n",
+    "                fmt = \"token_id\" if isinstance(fmt, str) and fmt.lower() == \"token_id\" else \"word\"\n",
+    "\n",
+    "                if isinstance(texts, list):\n",
+    "                    inputSentence += texts\n",
+    "                    batch_idx.append(len(texts))\n",
+    "                    formats += [fmt] * len(texts)\n",
+    "                else:\n",
+    "                    inputSentence.append(texts)\n",
+    "                    batch_idx.append(1)\n",
+    "                    formats.append(fmt)\n",
+    "\n",
+    "            # legacy schemas\n",
+    "            elif isinstance(request_body, list):\n",
     "                inputSentence += request_body\n",
     "                batch_idx.append(len(request_body))\n",
+    "                formats += [\"word\"] * len(request_body)\n",
     "            else:\n",
     "                inputSentence.append(request_body)\n",
     "                batch_idx.append(1)\n",
+    "                formats.append(\"word\")\n",
+    "\n",
+    "        return inputSentence, batch_idx, formats\n",
     "\n",
-    "        return inputSentence, batch_idx\n",
+    "    def _convert_token_ids(self, sparse_embedding):\n",
+    "        token_ids = self.model.tokenizer.convert_tokens_to_ids([x[0] for x in sparse_embedding])\n",
+    "        return [(str(token_ids[i]), sparse_embedding[i][1]) for i in range(len(token_ids))]\n",
     "\n",
     "    def handle(self, data, context):\n",
-    "        inputSentence, batch_idx = self.preprocess(data)\n",
+    "        inputSentence, batch_idx, formats = self._preprocess(data)\n",
     "        model_output = self.model.encode_document(inputSentence, batch_size=max_bs)\n",
-    "        sparse_embedding = list(map(dict,self.model.decode(model_output)))\n",
+    "\n",
+    "        sparse_embedding_word = self.model.decode(model_output)\n",
+    "        for i, fmt in enumerate(formats):\n",
+    "            if fmt == \"token_id\":\n",
+    "                sparse_embedding_word[i] = self._convert_token_ids(sparse_embedding_word[i])\n",
+    "        sparse_embedding = list(map(dict, sparse_embedding_word))\n",
     "\n",
     "        outputs = [sparse_embedding[s:e]\n",
     "                   for s, e in zip([0]+list(itertools.accumulate(batch_idx))[:-1],\n",
@@ -424,8 +464,8 @@
     "```json\n",
     "POST /_plugins/_ml/connectors/_create\n",
     "{\n",
-    "  \"name\": \"test\",\n",
-    "  \"description\": \"Test connector for Sagemaker model\",\n",
+    "  \"name\": \"Sagemaker Connector: embedding\",\n",
+    "  \"description\": \"The connector to sagemaker embedding model\",\n",
     "  \"version\": 1,\n",
     "  \"protocol\": \"aws_sigv4\",\n",
     "  \"credential\": {\n",
@@ -436,6 +476,7 @@
     "    \"region\": \"{region}\",\n",
     "    \"service_name\": \"sagemaker\",\n",
     "    \"input_docs_processed_step_size\": 2,\n",
+    "    \"sparse_embedding_format\": \"word\"\n",
     "  },\n",
     "  \"actions\": [\n",
     "    {\n",
@@ -445,7 +486,12 @@
     "        \"content-type\": \"application/json\"\n",
     "      },\n",
     "      \"url\": \"https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{predictor.endpoint_name}/invocations\",\n",
-    "      \"request_body\": \"${parameters.input}\"\n",
+    "      \"request_body\": \"\"\"\n",
+    "      {\n",
+    "        \"texts\": ${parameters.input},\n",
+    "        \"sparse_embedding_format\": \"${parameters.sparse_embedding_format}\"\n",
+    "      }\n",
+    "      \"\"\"\n",
     "    }\n",
     "  ],\n",
     "  \"client_config\":{\n",
