diff --git a/querybuilder/helpers/schema.py b/querybuilder/helpers/schema.py index 064e8761c188e232dd6777fb7ecbe6c339e497cc..759cad01a62ee69f04c5432ac23be27e6493d1ed 100644 --- a/querybuilder/helpers/schema.py +++ b/querybuilder/helpers/schema.py @@ -320,25 +320,26 @@ class ColumnSpec: def to_named_column( self: NamedColumn | ColumnSpec, - relation_name: Optional[str] = None, - schema_name: Optional[str] = None, + relation_name: str | None | _MISSING_TYPE = MISSING, + schema_name: str | None | _MISSING_TYPE = MISSING, ) -> NamedColumn: assert self.name assert self.sqltype - assert ( - not self.relation_name - or not relation_name - or self.relation_name == relation_name - ) - assert ( - not self.schema_name or not schema_name or self.schema_name == schema_name - ) - return NamedColumn( - self.sqltype, - self.name, - self.relation_name or relation_name, - self.schema_name or schema_name, - ) + if relation_name is MISSING: + relation_name = self.relation_name + elif relation_name is None: + schema_name = None + relation_name = cast(Optional[str], relation_name) + if schema_name is MISSING: + schema_name = self.schema_name + schema_name = cast(Optional[str], schema_name) + if relation_name: + assert not self.relation_name or self.relation_name == relation_name + if schema_name: + assert relation_name + assert not self.schema_name or self.schema_name == schema_name + + return NamedColumn(self.sqltype, self.name, relation_name, schema_name) @classmethod def _resolve_str_spec( @@ -467,7 +468,7 @@ class ColumnSpec: check: Optional[str | Column | qb.atoms.constraints.ColumnCheck | tuple | dict] if callable(self.check): - check = self.check(self.to_named_column(), rel) + check = self.check(self.to_named_column(relation_name=None), rel) else: check = self.check if check is not None: @@ -497,6 +498,7 @@ class ColumnSpec: qb.atoms.constraints.ColumnDefault, ] if callable(self.default): + # TODO: when is a callable default useful? default = self.default(self.to_named_column()) else: default = self.default