diff --git a/examples/networkdisk-rich-with-helper.py b/examples/networkdisk-rich-with-helper.py index 497e51ba11d24e9beed6d72c83390eb28557f972..d76efe3d00603249a60bd99b973fddf90ccbb392 100644 --- a/examples/networkdisk-rich-with-helper.py +++ b/examples/networkdisk-rich-with-helper.py @@ -81,7 +81,7 @@ class networkdisk: edge_data = schema.relations.edge_data rel0 = edges.left_join( edge_data, on=edges.columns.id.eq(edge_data.columns.edge) - ).select(where=lambda f: f.columns.key.neq("weight")) + ).select((*edges.columns[:3], "key", "value"), where=lambda f: f.columns.key.neq("weight")) rel1 = edges.select( columns=(*edges.columns[:3], make_column("weight"), edges.columns.weight), aliases={3: "key", 4: "value"}, diff --git a/querybuilder/atoms/clauses.py b/querybuilder/atoms/clauses.py index 25f7b03b17b3d9b0e889727ee53891d32d9b4d4a..1f120eb13467466f6c7cbb38fe3c8f4171d5247e 100644 --- a/querybuilder/atoms/clauses.py +++ b/querybuilder/atoms/clauses.py @@ -194,7 +194,9 @@ class WithClause(ClauseWrapper): class SetCombinatorWrapper(ClauseWrapper): __slots__ = ("combinator", "all") - def __init__(self, combinator: qbconstants.SetCombinator, all: bool = False): + def __init__( + self, combinator: qbconstants.SetCombinator, all: Optional[bool] = None + ): self.combinator = combinator self.all = all super().__init__() diff --git a/querybuilder/drivers/sql/tokenizer.py b/querybuilder/drivers/sql/tokenizer.py index b9ae9cf30b6b19133ab7602075c029980ac55e56..768e51c313c23a6cc0f0b62a07d6afc39ae67620 100644 --- a/querybuilder/drivers/sql/tokenizer.py +++ b/querybuilder/drivers/sql/tokenizer.py @@ -424,7 +424,7 @@ class Tokenizer: /, *, combinator: qbconstants.SetCombinator, - all: bool, + all: Optional[bool] = None, ) -> TkTree: comb = f"{combinator.value}" if all: diff --git a/querybuilder/drivers/sqlite/specifics.py b/querybuilder/drivers/sqlite/specifics.py new file mode 100644 index 0000000000000000000000000000000000000000..653a29ac44bddb609fa0afd0ec174d533e3b678d --- /dev/null +++ b/querybuilder/drivers/sqlite/specifics.py @@ -0,0 +1,55 @@ +import querybuilder as qb + + +class ScopedDQL(qb.atoms.relations.Fromable): + """A DQL query nested within parenthesis + + Be aware that, since ScopedDQL are sqlite specifics, they cannot be tokenized using + the standard-SQL tokenizer (namely, qb.drivers.sql.tokenizer.Tokenizer). This will + make the pretty printing as well as the stringification (`__str__`) fail, by + default. Consider setting the default tokenizer to the sqlite one (namely, + `qb.drivers.sqlite.tokenizer.Tokenizer`) so that both pretty printing and + stringification would work (see examples below). + + + Parameters + ---------- + query: qb.queries.dql.DQLQuery + the wrapped DQL query. + + kwargs: dict[str, Any] + additional keyworded parameters for super initialization. + + Examples + -------- + For pretty printing and stringification (`__str__`) of atoms to work properly, we + need to set the default tokenizer of querybuilder to the sqlite one: + >>> qb.settings["tokenizer"] = qb.drivers.sqlite.tokenizer.Tokenizer() + + Now, we can use our sqlite-specific ScopedDQL class: + >>> from querybuilder.helpers import make_column + >>> q01 = qb.queries.dql.Select([make_column(0), make_column(1)]) + >>> r01 = ScopedDQL(q01) + >>> str(r01) + '(SELECT 0, 1)' + >>> q12 = qb.queries.dql.Select([make_column(1), make_column(2)]) + >>> r12 = ScopedDQL(q12) + >>> str(r12) + '(SELECT 1, 2)' + >>> str(q12.union_all(ScopedDQL(q01.except_(q01)))) + 'SELECT 1, 2 UNION ALL (SELECT 0, 1 EXCEPT SELECT 0, 1)' + >>> r = qb.atoms.relations.CartesianProduct([r01, r12]) + >>> str(r.select("*")) + 'SELECT * FROM (SELECT 0, 1), (SELECT 1, 2)' + """ + + __slots__ = ("query",) + + def __init__(self, query: qb.queries.dql.DQLQuery, **kwargs): + self.query = query + super().__init__(columns=query.columns, **kwargs) + + def _get_subtokenize_kwargs(self, tokenizer): + kwargs = super()._get_subtokenize_kwargs(tokenizer) + kwargs["query"] = self.query.subtokenize(tokenizer, scoped=True) + return kwargs diff --git a/querybuilder/drivers/sqlite/tokenizer.py b/querybuilder/drivers/sqlite/tokenizer.py index f600451103dd241b9fd9fed15e524fa528feedc0..139730b78e90eda6dcfc99300be63174aafb044d 100644 --- a/querybuilder/drivers/sqlite/tokenizer.py +++ b/querybuilder/drivers/sqlite/tokenizer.py @@ -6,7 +6,8 @@ from typing import Callable, List, Optional, cast from querybuilder.utils.decorators import TypeDispatch from querybuilder.formatting.tokentree import TkInd, TkTree, TkSeq, TkStr from querybuilder.drivers.sql.tokenizer import Tokenizer as sqlTokenizer, name_and_defer -from querybuilder.atoms.clauses import Limit, OrderColumn +from querybuilder.drivers.sqlite.specifics import ScopedDQL +from querybuilder.atoms.clauses import Limit, OrderColumn, SetCombinatorWrapper from querybuilder.atoms.atoms import Atom import querybuilder.atoms.relations as qbrelations import querybuilder.atoms.pseudo_columns as qbpseudo_columns @@ -52,6 +53,10 @@ class Tokenizer(sqlTokenizer): return super().__call__(*args, **kwargs) # DQL + @__call__.register(ScopedDQL) + def _(self, obj: qb.queries.dql.Select, /, *, query: TkTree): + return query + @__call__.register(Limit) def _( self, @@ -89,6 +94,22 @@ class Tokenizer(sqlTokenizer): newtree.append(self.tokenize_keyword("DESC")) return tuple(newtree) + @__call__.register(SetCombinatorWrapper) + def _( + self, + obj: qb.atoms.clauses.SetCombinatorWrapper, + /, + *, + combinator: qb.utils.constants.SetCombinator, + all: Optional[bool] = None, + ) -> TkTree: + # sqlite does not support INTERSECT ALL / EXCEPT ALL + if all and combinator != qb.utils.constants.SetCombinator.UNION: + raise ValueError(f"sqlite does not support {combinator} ALL") + # sqlite does not support INTERSECT DISTINCT / EXCEPT DISTINCT / UNION DISTINCT + # ⟶ fallback to INTERSECT / EXCEPT / UNION + return super().__call__(obj, combinator=combinator, all=all or None) + # TCL @__call__.register(qb.queries.tcl.Start) def _( @@ -234,4 +255,19 @@ class Tokenizer(sqlTokenizer): return rel.alias(obj.name) - del _t + @transform.register(qb.queries.dql.SetCombination) + def _t2(self, obj: qb.queries.dql.SetCombination) -> qb.queries.dql.SetCombination: + left = right = None + if isinstance(obj.left, qb.queries.dql.SetCombination): + left = self.transform(obj.left) + if isinstance(left, qb.queries.dql.SetCombination): + left = ScopedDQL(query=left).select([qb.atoms.pseudo_columns.Star()]) + if isinstance(obj.right, qb.queries.dql.SetCombination): + right = self.transform(obj.right) + if isinstance(right, qb.queries.dql.SetCombination): + right = ScopedDQL(query=right).select([qb.atoms.pseudo_columns.Star()]) + if left or right: + obj = obj.buildfrom(obj, left=left or obj.left, right=right or obj.right) + return obj + + del _t, _t2 diff --git a/querybuilder/queries/dql.py b/querybuilder/queries/dql.py index b9da87b21dbfb6d8127b8ea8fe1c8e73f42acac6..864b71c3c6cacbf842e7e930c284c8f790e6a237 100644 --- a/querybuilder/queries/dql.py +++ b/querybuilder/queries/dql.py @@ -545,7 +545,7 @@ class SetCombination(DQLQuery): Additional keyworded parameters for super initialization. """ - __slots__ = ("combinator", "all", "subrelations") + __slots__ = ("combinator", "all", "left", "right") _scopable = True def __init__( @@ -553,41 +553,36 @@ class SetCombination(DQLQuery): combinator: qbconstants.SetCombinator, left: DQLQuery, right: DQLQuery, - all: bool = False, + all: Optional[bool] = None, **kwargs, ): + assert left.arity == right.arity self.combinator = combinator self.all = all - self.subrelations = (left, right) + self.left = left + self.right = right super().__init__(columns=left.columns, **kwargs) - @property - def left(self): - return self.subrelations[0] - - @property - def right(self): - return self.subrelations[1] - @method_accepting_lambdas def set_aliases(self, aliases): - return self.buildfrom( - self, subrelations=(self.left.set_aliases(aliases), self.right) - ) + return self.buildfrom(self, left=self.left.set_aliases(aliases)) def _post_getstate(self, state): state.pop("columns") return state def __setstate__(self, state): - state.setdefault("columns", state["subrelations"][0].columns) + state.setdefault("columns", state["left"].columns) super().__setstate__(state) def _get_subtokenize_kwargs(self, tokenizer): + scope_left = isinstance(self.left, SetCombination) + scope_right = isinstance(self.right, SetCombination) + return dict( subrelations=( - self.left.subtokenize(tokenizer), - self.right.subtokenize(tokenizer), + self.left.subtokenize(tokenizer, scoped=scope_left), + self.right.subtokenize(tokenizer, scoped=scope_right), ), combinator=qb.atoms.clauses.SetCombinatorWrapper( self.combinator, all=self.all @@ -598,7 +593,7 @@ class SetCombination(DQLQuery): def _substitute(self, substitutions: Mapping) -> Self: left = self.left.substitute(substitutions) right = self.right.substitute(substitutions) - return self.buildfrom(self, subrelations=(left, right)) + return self.buildfrom(self, left=left, right=right) class WithClosure(DQLQuery): diff --git a/querybuilder/tests/drivers/sql/test_tokenizer.py b/querybuilder/tests/drivers/sql/test_tokenizer.py index 6be81c6afb8cc3cda8311e4abc26dbf8ae970514..6ee088bdddfe487840216f4ef69e0b6308a52f5d 100644 --- a/querybuilder/tests/drivers/sql/test_tokenizer.py +++ b/querybuilder/tests/drivers/sql/test_tokenizer.py @@ -1063,6 +1063,29 @@ class TestSQLTokenizer: assert expected == result + def test_set_combinator_wrapper_with_None_all(self): + combinator = qbconstants.SetCombinator.UNION + result = self.tk( + self.get_empty_instance(qb.atoms.clauses.SetCombinatorWrapper), + combinator=combinator, + all=None, + ) + + expected = (TkSeq((TkStr(qbtoken.Keyword, combinator.value),)),) + + assert expected == result + + def test_set_combinator_wrapper_without_all(self): + combinator = qbconstants.SetCombinator.UNION + result = self.tk( + self.get_empty_instance(qb.atoms.clauses.SetCombinatorWrapper), + combinator=combinator, + ) + + expected = (TkSeq((TkStr(qbtoken.Keyword, combinator.value),)),) + + assert expected == result + def test_set_columns(self): set_columns = self.get_dummy_columns(3) diff --git a/querybuilder/tests/drivers/sqlite/test_sqlite_specifics.py b/querybuilder/tests/drivers/sqlite/test_sqlite_specifics.py new file mode 100644 index 0000000000000000000000000000000000000000..f294a6afcf1353b255c656c80560e9e564cd756b --- /dev/null +++ b/querybuilder/tests/drivers/sqlite/test_sqlite_specifics.py @@ -0,0 +1,20 @@ +from mock import Mock +import querybuilder.drivers.sqlite.specifics as specifics + + +class TestScopedDQL: + def test_get_subtokenize_kwargs(self): + query_tok = "query" + tk = Mock() + query = Mock() + query.subtokenize = Mock(return_value=query_tok) + query.columns = [] + + scoped = specifics.ScopedDQL(query) + + kwargs = scoped._get_subtokenize_kwargs(tk) + + expected_kwargs = {"query": query_tok} + + assert expected_kwargs == kwargs + query.subtokenize.assert_called_with(tk, scoped=True) diff --git a/querybuilder/tests/drivers/sqlite/test_sqlite_tokenizer.py b/querybuilder/tests/drivers/sqlite/test_sqlite_tokenizer.py index 4216e571a4263205f332072023c02d369fe6adbd..d262e07b0c7640bf6d62f16c7ce967d4b4c23efb 100644 --- a/querybuilder/tests/drivers/sqlite/test_sqlite_tokenizer.py +++ b/querybuilder/tests/drivers/sqlite/test_sqlite_tokenizer.py @@ -145,3 +145,25 @@ class TestSqliteTokenizer(parent_suite.TestSQLTokenizer): assert rel.name == transrel.name assert rel.schema_name == transrel.schema_name + + def test_set_combinator_wrapper_with_False_all(self): + # sqlite does not support UNION/INTERSECT/EXCEPT DISTINCT + # ⟶ fallback to UNION/INTERSECT/EXCEPT + combinator = qb.utils.constants.SetCombinator.UNION + result = self.tk( + self.get_empty_instance(qb.atoms.clauses.SetCombinatorWrapper), + combinator=combinator, + ) + + expected = (TkSeq((TkStr(qbtoken.Keyword, combinator.value),)),) + + assert expected == result + + def test_scoped_dql(self): + query = self.get_dummy_tkseq("query") + result = self.tk( + self.get_empty_instance(qb.drivers.sqlite.specifics.ScopedDQL), + query=query, + ) + + assert query == result diff --git a/querybuilder/tests/queries/test_dql.py b/querybuilder/tests/queries/test_dql.py index 5eab126f0ce86ec5944c45e711fca1df45820328..88ecd3a784ad8a7fa3ff6dd4e9215c199a83fe52 100644 --- a/querybuilder/tests/queries/test_dql.py +++ b/querybuilder/tests/queries/test_dql.py @@ -1,5 +1,9 @@ import pytest +from mock import Mock +from querybuilder.tests.utils import create_subtokenizable_mock + from querybuilder.helpers import make_column +from querybuilder.utils.constants import SetCombinator import querybuilder.atoms.pseudo_columns as qbpseudo_columns import querybuilder.atoms.columns as qbcolumns import querybuilder.atoms.relations as qbrelations @@ -28,8 +32,6 @@ class TestDQL: class TestSelect(TestDQL): - tk = qb.drivers.sql.tokenizer.Tokenizer() - def get_dql_queries(self): columns = [qbcolumns.Named(int, f"c{i}") for i in range(5)] yield dql.Select(columns) @@ -566,6 +568,71 @@ class TestSetCombination(TestDQL): assert post_query == post_comb.left assert post_query == post_comb.right + @pytest.mark.parametrize( + "combinator, all", [(x, y) for x in SetCombinator for y in [None, True, False]] + ) + def test_simple_set_combination(self, combinator, all): + left = dql.Select([make_column(0), make_column(1)]) + right = dql.Select([make_column(1), make_column(2)]) + q = dql.SetCombination(combinator, left, right, all=all) + assert q.combinator == combinator + assert q.arity == left.arity == right.arity + assert q.left == left + assert q.right == right + assert q.all == all + + @pytest.mark.parametrize( + "l_cls, r_cls, l_scope, r_scope", + [ + (qb.queries.dql.SetCombination, qb.queries.dql.SetCombination, True, True), + (qb.queries.dql.SetCombination, qb.queries.dql.Select, True, False), + (qb.queries.dql.Select, qb.queries.dql.SetCombination, False, True), + (qb.queries.dql.Select, qb.queries.dql.Select, False, False), + ], + ) + def test_get_subtokenize_kwargs(self, mocker, l_cls, r_cls, l_scope, r_scope): + left_tok = "left" + right_tok = "right" + comb_tok = "comb" + + comb = qb.utils.constants.SetCombinator.UNION + all = True + + left = Mock() + left.__class__ = l_cls + left.subtokenize = Mock(return_value=left_tok) + left.arity = 2 + left.columns = (Mock(), Mock()) + + right = Mock() + right.__class__ = r_cls + right.subtokenize = Mock(return_value=right_tok) + right.arity = 2 + right.columns = (Mock(), Mock()) + + tokenizer = Mock() + + create_subtokenizable_mock( + mocker, "querybuilder.atoms.clauses.SetCombinatorWrapper", comb_tok + ) + + rel = qb.queries.dql.SetCombination(comb, left, right, all=all) + + expected = { + "subrelations": (left_tok, right_tok), + "combinator": comb_tok, + } + + result = rel._get_subtokenize_kwargs(tokenizer) + + assert expected == result + + left.subtokenize.assert_called_once_with(tokenizer, scoped=l_scope) + right.subtokenize.assert_called_once_with(tokenizer, scoped=r_scope) + + SCW = qb.atoms.clauses.SetCombinatorWrapper + SCW.assert_called_with(comb, all=all) + class TestWithClosure: def test_substitute(self):