@@ -317,6 +317,13 @@ async def query_media(
317317 int | None ,
318318 Field (description = "Match a TMDb identifier" , examples = [568467 ]),
319319 ] = None ,
320+ similar_to : Annotated [
321+ str | Sequence [str ] | None ,
322+ Field (
323+ description = "Recommend candidates similar to these identifiers" ,
324+ examples = [["49915" ], "tt8367814" ],
325+ ),
326+ ] = None ,
320327 limit : Annotated [
321328 int ,
322329 Field (
@@ -337,24 +344,37 @@ def _listify(value: Sequence[str] | str | None) -> list[str]:
337344 return [v for v in value if isinstance (v , str ) and v ]
338345
339346 vector_queries : list [tuple [str , models .Document ]] = []
340- if dense_query :
341- vector_queries .append (
342- (
343- "dense" ,
344- models .Document (
345- text = dense_query , model = server .settings .dense_model
346- ),
347+ positive_point_ids : list [Any ] = []
348+ similar_identifiers = _listify (similar_to )
349+ if similar_identifiers :
350+ for identifier in similar_identifiers :
351+ records = await media_helpers ._find_records (
352+ server , identifier , limit = 1
347353 )
348- )
349- if sparse_query :
350- vector_queries .append (
351- (
352- "sparse" ,
353- models .Document (
354- text = sparse_query , model = server .settings .sparse_model
355- ),
354+ for record in records :
355+ if record .id is not None :
356+ positive_point_ids .append (record .id )
357+ if not positive_point_ids :
358+ return []
359+ if not positive_point_ids :
360+ if dense_query :
361+ vector_queries .append (
362+ (
363+ "dense" ,
364+ models .Document (
365+ text = dense_query , model = server .settings .dense_model
366+ ),
367+ )
368+ )
369+ if sparse_query :
370+ vector_queries .append (
371+ (
372+ "sparse" ,
373+ models .Document (
374+ text = sparse_query , model = server .settings .sparse_model
375+ ),
376+ )
356377 )
357- )
358378
359379 must : list [models .FieldCondition ] = []
360380 keyword_prefetch_conditions : list [models .FieldCondition ] = []
@@ -503,7 +523,20 @@ def _listify(value: Sequence[str] | str | None) -> list[str]:
503523 query_obj : models .Query | None = None
504524 using_param : str | None = None
505525 prefetch_param : Sequence [models .Prefetch ] | None = None
506- if vector_queries :
526+ prefetch_entries : list [models .Prefetch ] = []
527+ if positive_point_ids :
528+ recommend_query = models .RecommendQuery (
529+ recommend = models .RecommendInput (positive = positive_point_ids )
530+ )
531+ prefetch_entries .append (
532+ models .Prefetch (
533+ query = recommend_query ,
534+ using = "dense" ,
535+ limit = limit ,
536+ filter = prefetch_filter ,
537+ )
538+ )
539+ if not positive_point_ids and vector_queries :
507540 candidate_limit = limit * 3 if len (vector_queries ) > 1 else limit
508541 prefetch_entries = [
509542 models .Prefetch (
@@ -514,6 +547,8 @@ def _listify(value: Sequence[str] | str | None) -> list[str]:
514547 )
515548 for name , doc in vector_queries
516549 ]
550+
551+ if prefetch_entries :
517552 if len (prefetch_entries ) > 1 :
518553 query_obj = models .FusionQuery (fusion = models .Fusion .RRF )
519554 using_param = None
0 commit comments