diff --git a/redisvl/query/filter.py b/redisvl/query/filter.py
index 3b6d8496..0295568f 100644
--- a/redisvl/query/filter.py
+++ b/redisvl/query/filter.py
@@ -120,11 +120,13 @@ class Tag(FilterField):
         FilterOperator.EQ: "==",
         FilterOperator.NE: "!=",
         FilterOperator.IN: "==",
+        FilterOperator.LIKE: "%",
     }
     OPERATOR_MAP: Dict[FilterOperator, str] = {
         FilterOperator.EQ: "@%s:{%s}",
         FilterOperator.NE: "(-@%s:{%s})",
         FilterOperator.IN: "@%s:{%s}",
+        FilterOperator.LIKE: "@%s:{%s}",
     }
     SUPPORTED_VAL_TYPES = (list, set, tuple, str, type(None))
 
@@ -177,9 +179,41 @@ def __ne__(self, other) -> "FilterExpression":
         self._set_tag_value(other, FilterOperator.NE)
         return FilterExpression(str(self))
 
+    def __mod__(self, other: Union[List[str], str]) -> "FilterExpression":
+        """Create a Tag wildcard filter expression for pattern matching.
+
+        This enables wildcard pattern matching on tag fields using the ``*``
+        character. Unlike the equality operator, wildcards are not escaped,
+        allowing patterns with wildcards in any position, such as prefix
+        (``"tech*"``), suffix (``"*tech"``), or middle (``"*tech*"``)
+        matches.
+
+        Args:
+            other (Union[List[str], str]): The tag pattern(s) to filter on.
+                Use ``*`` for wildcard matching (e.g., ``"tech*"``, ``"*tech"``,
+                or ``"*tech*"``).
+
+        .. code-block:: python
+
+            from redisvl.query.filter import Tag
+
+            f = Tag("category") % "tech*"  # Prefix match
+            f = Tag("category") % "*tech"  # Suffix match
+            f = Tag("category") % "*tech*"  # Contains match
+            f = Tag("category") % "elec*|*soft"  # Multiple wildcard patterns
+            f = Tag("category") % ["tech*", "*science"]  # List of patterns
+
+        """
+        self._set_tag_value(other, FilterOperator.LIKE)
+        return FilterExpression(str(self))
+
     @property
     def _formatted_tag_value(self) -> str:
-        return "|".join([self.escaper.escape(tag) for tag in self._value])
+        # For LIKE operator, preserve wildcards (*) in the pattern
+        preserve_wildcards = self._operator == FilterOperator.LIKE
+        return "|".join(
+            [self.escaper.escape(tag, preserve_wildcards) for tag in self._value]
+        )
 
     def __str__(self) -> str:
         """Return the Redis Query string for the Tag filter"""
diff --git a/redisvl/utils/token_escaper.py b/redisvl/utils/token_escaper.py
--- a/redisvl/utils/token_escaper.py
+++ b/redisvl/utils/token_escaper.py
@@ -12,13 +12,30 @@ class TokenEscaper:
     # Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
     DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"
 
     def __init__(self, escape_chars_re: Optional[Pattern] = None):
         if escape_chars_re:
             self.escaped_chars_re = escape_chars_re
         else:
             self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
 
-    def escape(self, value: str) -> str:
+    def escape(self, value: str, preserve_wildcards: bool = False) -> str:
+        """Escape special characters in a string for use in Redis queries.
+
+        Every match of ``self.escaped_chars_re`` is prefixed with a backslash.
+        This applies to a custom pattern passed to the constructor as well as
+        to the default one.
+
+        Args:
+            value: The string value to escape.
+            preserve_wildcards: If True, a matched ``*`` is left unescaped so
+                it can act as a wildcard. Defaults to False.
+
+        Returns:
+            The escaped string.
+
+        Raises:
+            TypeError: If value is not a string.
+        """
         if not isinstance(value, str):
             raise TypeError(
                 f"Value must be a string object for token escaping, got type {type(value)}"
@@ -28,4 +45,7 @@ def escape_symbol(match):
             value = match.group(0)
+            # Leave "*" untouched in wildcard mode so it keeps its pattern meaning.
+            if preserve_wildcards and value == "*":
+                return value
             return f"\\{value}"
 
         return self.escaped_chars_re.sub(escape_symbol, value)
diff --git a/tests/unit/test_filter.py b/tests/unit/test_filter.py
index b4890028..eb704817 100644
--- a/tests/unit/test_filter.py
+++ b/tests/unit/test_filter.py
@@ -55,6 +55,81 @@ def test_tag_filter_varied(operation, tags, expected):
     assert str(tf) == expected
 
 
+@pytest.mark.parametrize(
+    "pattern,expected",
+    [
+        # Basic prefix wildcard
+        ("tech*", "@tag_field:{tech*}"),
+        # Multiple patterns via list
+        (["tech*", "soft*"], "@tag_field:{tech*|soft*}"),
+        # Wildcard with special chars that still get escaped
+        ("tech*-pro", "@tag_field:{tech*\\-pro}"),
+        # Prefix with space (space escaped, wildcard preserved)
+        ("hello w*", "@tag_field:{hello\\ w*}"),
+        # Multiple wildcards in same pattern
+        ("*test*", "@tag_field:{*test*}"),
+        # Empty pattern returns wildcard match-all
+        ("", "*"),
+        ([], "*"),
+        (None, "*"),
+        # Pattern with special characters
+        ("cat$*", "@tag_field:{cat\\$*}"),
+    ],
+    ids=[
+        "prefix_wildcard",
+        "multiple_patterns",
+        "wildcard_with_special_char",
+        "prefix_with_space",
+        "multiple_wildcards",
+        "empty_string",
+        "empty_list",
+        "none",
+        "special_char_with_wildcard",
+    ],
+)
+def test_tag_wildcard_filter(pattern, expected):
+    """Test Tag % operator for wildcard/prefix matching."""
+    tf = Tag("tag_field") % pattern
+    assert str(tf) == expected
+
+
+def test_tag_wildcard_preserves_asterisk():
+    """Verify that * is not escaped when using % operator."""
+    # With == operator, * should be escaped
+    tf_eq = Tag("tag_field") == "tech*"
+    assert str(tf_eq) == "@tag_field:{tech\\*}"
+
+    # With % operator, * should NOT be escaped
+    tf_like = Tag("tag_field") % "tech*"
+    assert str(tf_like) == "@tag_field:{tech*}"
+
+
+def test_tag_wildcard_combined_with_exact_match():
+    """Test combining wildcard and exact match Tag filters in the same query."""
+    # Create filters with different operators
+    exact_match = Tag("brand") == "nike"
+    wildcard_match = Tag("category") % "tech*"
+
+    # Verify individual filters work correctly
+    assert str(exact_match) == "@brand:{nike}"
+    assert str(wildcard_match) == "@category:{tech*}"
+
+    # Combine with AND - wildcard should be preserved, exact match should not have *
+    combined_and = exact_match & wildcard_match
+    assert str(combined_and) == "(@brand:{nike} @category:{tech*})"
+
+    # Combine with OR
+    combined_or = exact_match | wildcard_match
+    assert str(combined_or) == "(@brand:{nike} | @category:{tech*})"
+
+    # More complex: mix of exact, wildcard, and exact with * in value
+    exact_with_asterisk = Tag("status") == "active*"  # * should be escaped
+    complex_filter = exact_match & wildcard_match & exact_with_asterisk
+    assert "@brand:{nike}" in str(complex_filter)
+    assert "@category:{tech*}" in str(complex_filter)  # wildcard preserved
+    assert "@status:{active\\*}" in str(complex_filter)  # asterisk escaped
+
+
 def test_nullable():
     tag = Tag("tag_field") == None
     assert str(tag) == "*"