diff --git a/tortoise/contrib/postgres/functions.py b/tortoise/contrib/postgres/functions.py index 823b4aa4c..bdefae2be 100644 --- a/tortoise/contrib/postgres/functions.py +++ b/tortoise/contrib/postgres/functions.py @@ -1,13 +1,15 @@ from pypika_tortoise.terms import Function, Term +DEFAULT_TEXT_SEARCH_CONFIG = "pg_catalog.simple" + class ToTsVector(Function): """ to to_tsvector function """ - def __init__(self, field: Term) -> None: - super().__init__("TO_TSVECTOR", field) + def __init__(self, field: Term, config_name: str = DEFAULT_TEXT_SEARCH_CONFIG) -> None: + super().__init__("TO_TSVECTOR", config_name, field) class ToTsQuery(Function): @@ -15,8 +17,8 @@ class ToTsQuery(Function): to_tsquery function """ - def __init__(self, field: Term) -> None: - super().__init__("TO_TSQUERY", field) + def __init__(self, field: Term, config_name: str = DEFAULT_TEXT_SEARCH_CONFIG) -> None: + super().__init__("TO_TSQUERY", config_name, field) class PlainToTsQuery(Function): @@ -24,8 +26,8 @@ class PlainToTsQuery(Function): plainto_tsquery function """ - def __init__(self, field: Term) -> None: - super().__init__("PLAINTO_TSQUERY", field) + def __init__(self, field: Term, config_name: str = DEFAULT_TEXT_SEARCH_CONFIG) -> None: + super().__init__("PLAINTO_TSQUERY", config_name, field) class Random(Function): diff --git a/tortoise/contrib/postgres/search.py b/tortoise/contrib/postgres/search.py index ac41d25a6..e6c4fb779 100644 --- a/tortoise/contrib/postgres/search.py +++ b/tortoise/contrib/postgres/search.py @@ -10,10 +10,8 @@ class Comp(Comparator): search = " @@ " -class SearchCriterion(BasicCriterion): +class SearchCriterion(BasicCriterion): # type: ignore def __init__(self, field: Term, expr: Union[Term, Function]) -> None: - if isinstance(expr, Function): - _expr = expr - else: - _expr = ToTsQuery(expr) - super().__init__(Comp.search, ToTsVector(field), _expr) + if not isinstance(expr, Function): + expr = ToTsQuery(expr) + super().__init__(Comp.search, ToTsVector(config_name=expr.args[0].value, field=field), expr)