Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 68856f7

Browse files
authored
Fix: Support dbt_utils.star except argument (SQLMesh#973)
1 parent bcda5e9 commit 68856f7

8 files changed

Lines changed: 112 additions & 34 deletions

File tree

sqlmesh/dbt/adapter.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ def __init__(
2323
self,
2424
jinja_macros: JinjaMacroRegistry,
2525
jinja_globals: t.Optional[t.Dict[str, t.Any]] = None,
26+
dialect: t.Optional[str] = None,
2627
):
2728
self.jinja_macros = jinja_macros
2829
self.jinja_globals = jinja_globals.copy() if jinja_globals else {}
2930
self.jinja_globals["adapter"] = self
31+
self.dialect = dialect
3032

3133
@abc.abstractmethod
3234
def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]:
@@ -71,9 +73,9 @@ def execute(
7173
) -> t.Optional[t.Tuple[AdapterResponse, agate.Table]]:
7274
"""Executes the given SQL statement and returns the results as an agate table."""
7375

74-
@abc.abstractmethod
7576
def quote(self, identifier: str) -> str:
76-
"""Returns a quoted identifeir."""
77+
"""Returns a quoted identifier."""
78+
return exp.to_column(identifier).sql(dialect=self.dialect, identify=True)
7779

7880
def dispatch(self, name: str, package: t.Optional[str] = None) -> t.Callable:
7981
"""Returns a dialect-specific version of a macro with the given name."""
@@ -124,9 +126,6 @@ def execute(
124126
) -> t.Optional[t.Tuple[AdapterResponse, agate.Table]]:
125127
return None
126128

127-
def quote(self, identifier: str) -> str:
128-
return identifier
129-
130129

131130
class RuntimeAdapter(BaseAdapter):
132131
def __init__(
@@ -137,7 +136,7 @@ def __init__(
137136
):
138137
from dbt.adapters.base.relation import Policy
139138

140-
super().__init__(jinja_macros, jinja_globals=jinja_globals)
139+
super().__init__(jinja_macros, jinja_globals=jinja_globals, dialect=engine_adapter.dialect)
141140

142141
self.engine_adapter = engine_adapter
143142
# All engines quote by default except Snowflake
@@ -244,6 +243,3 @@ def execute(
244243
assert isinstance(resp, pd.DataFrame)
245244
return AdapterResponse("Success"), pandas_to_agate(resp)
246245
return AdapterResponse("Success"), empty_table()
247-
248-
def quote(self, identifier: str) -> str:
249-
return exp.to_column(identifier).sql(dialect=self.engine_adapter.dialect, identify=True)

sqlmesh/dbt/builtin.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,9 @@ def create_builtin_globals(
308308

309309
if engine_adapter is not None:
310310
adapter = RuntimeAdapter(
311-
engine_adapter, jinja_macros, jinja_globals={**builtin_globals, **jinja_globals}
311+
engine_adapter,
312+
jinja_macros,
313+
jinja_globals={**builtin_globals, **jinja_globals},
312314
)
313315
sql_execution = SQLExecution(adapter)
314316
builtin_globals.update(
@@ -325,11 +327,14 @@ def create_builtin_globals(
325327
}
326328
)
327329
else:
330+
target = jinja_globals.get("target")
328331
builtin_globals.update(
329332
{
330333
"execute": False,
331334
"adapter": ParsetimeAdapter(
332-
jinja_macros, jinja_globals={**builtin_globals, **jinja_globals}
335+
jinja_macros,
336+
jinja_globals={**builtin_globals, **jinja_globals},
337+
dialect=target.type if target else None,
333338
),
334339
"load_relation": lambda *args, **kwargs: None,
335340
"store_result": lambda *args, **kwargs: "",

sqlmesh/dbt/context.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,7 @@ def copy(self) -> DbtContext:
159159
@property
160160
def jinja_environment(self) -> Environment:
161161
if self._jinja_environment is None:
162-
self._jinja_environment = self.jinja_macros.build_environment(
163-
**self.jinja_globals, engine_adapter=self.engine_adapter
164-
)
162+
self._jinja_environment = self.jinja_macros.build_environment(**self.jinja_globals)
165163
return self._jinja_environment
166164

167165
@property

sqlmesh/dbt/macros.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from pathlib import Path
2+
3+
from sqlmesh.dbt.package import MacroConfig, MacroInfo
4+
5+
6+
def dbt_utils_star() -> MacroConfig:
7+
definition = """
8+
{% macro star(from, relation_alias=False, except=[], prefix='', suffix='', quote_identifiers=True) -%}
9+
{%- if prefix != '' -%}
10+
{{ exceptions.raise_compiler_error("prefix argument not currently supported for dbt_utils.star macro") }}
11+
{%- elif suffix != '' -%}
12+
{{ exceptions.raise_compiler_error("suffix argument not currently supported for dbt_utils.star macro") }}
13+
{%- endif -%}
14+
15+
{{ from }}.*
16+
{%- if except|length > 0 %} EXCEPT (
17+
{%- for col in except -%}
18+
{%- if not loop.first %}, {% endif -%}
19+
{%- if quote_identifiers -%}{{ adapter.quote(col)|trim }}{%- else -%}{{ col|trim }}{%- endif -%}
20+
{%- endfor -%}
21+
)
22+
{%- endif -%}
23+
{% endmacro %}
24+
"""
25+
return MacroConfig(info=MacroInfo(definition=definition, depends_on=[]), path=Path())
26+
27+
28+
MACRO_OVERRIDES = {
29+
"dbt_utils": {
30+
"star": dbt_utils_star(),
31+
}
32+
}

sqlmesh/dbt/manifest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from dbt.version import get_installed_version
1818

1919
from sqlmesh.dbt.basemodel import Dependencies
20+
from sqlmesh.dbt.macros import MACRO_OVERRIDES
2021
from sqlmesh.dbt.model import ModelConfig
2122
from sqlmesh.dbt.package import MacroConfig
2223
from sqlmesh.dbt.seed import SeedConfig
@@ -119,7 +120,10 @@ def _load_macros(self) -> None:
119120
if macro.name.startswith("test_"):
120121
macro.macro_sql = _convert_jinja_test_to_macro(macro.macro_sql)
121122

122-
self._macros_per_package[macro.package_name][macro.name] = MacroConfig(
123+
package_overrides = MACRO_OVERRIDES.get(macro.package_name, {})
124+
self._macros_per_package[macro.package_name][macro.name] = package_overrides.get(
125+
macro.name
126+
) or MacroConfig(
123127
info=MacroInfo(
124128
definition=macro.macro_sql,
125129
depends_on=list(_macro_references(self._manifest, macro)),

tests/dbt/conftest.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
from __future__ import annotations
22

3+
import typing as t
34
from pathlib import Path
45

56
import pytest
6-
from pytest_mock.plugin import MockerFixture
77

88
from sqlmesh.dbt.context import DbtContext
99
from sqlmesh.dbt.project import Project
1010
from tests.conftest import delete_cache
1111

1212

1313
@pytest.fixture()
14-
def sushi_test_project(mocker: MockerFixture) -> Project:
14+
def sushi_test_project() -> Project:
1515
project_root = "tests/fixtures/dbt/sushi_test"
1616
delete_cache(project_root)
1717
project = Project.load(DbtContext(project_root=Path(project_root)))
@@ -21,3 +21,18 @@ def sushi_test_project(mocker: MockerFixture) -> Project:
2121
package=package_name if package_name != project.context.project_name else None,
2222
)
2323
return project
24+
25+
26+
@pytest.fixture()
27+
def runtime_renderer() -> t.Callable:
28+
def create_renderer(context: DbtContext) -> t.Callable:
29+
environment = context.jinja_macros.build_environment(
30+
**context.jinja_globals, engine_adapter=context.engine_adapter
31+
)
32+
33+
def render(value: str) -> str:
34+
return environment.from_string(value).render()
35+
36+
return render
37+
38+
return create_renderer

tests/dbt/test_adapter.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import typing as t
4+
35
import pytest
46
from dbt.adapters.base.column import Column
57
from sqlglot import exp
@@ -8,8 +10,9 @@
810
from sqlmesh.utils.errors import ConfigError
911

1012

11-
def test_adapter_relation(sushi_test_project: Project):
13+
def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Callable):
1214
context = sushi_test_project.context
15+
renderer = runtime_renderer(context)
1316
assert context.engine_adapter
1417

1518
engine_adapter = context.engine_adapter
@@ -26,17 +29,17 @@ def test_adapter_relation(sushi_test_project: Project):
2629
)
2730

2831
assert (
29-
context.render("{{ adapter.get_relation(database=None, schema='foo', identifier='bar') }}")
32+
renderer("{{ adapter.get_relation(database=None, schema='foo', identifier='bar') }}")
3033
== '"foo"."bar"'
3134
)
32-
assert context.render(
35+
assert renderer(
3336
"{%- set relation = adapter.get_relation(database=None, schema='foo', identifier='bar') -%} {{ adapter.get_columns_in_relation(relation) }}"
3437
) == str([Column.from_description(name="baz", raw_data_type="INT")])
3538

36-
assert context.render("{{ adapter.list_relations(database=None, schema='foo')|length }}") == "2"
39+
assert renderer("{{ adapter.list_relations(database=None, schema='foo')|length }}") == "2"
3740

3841
assert (
39-
context.render(
42+
renderer(
4043
"""
4144
{%- set from = adapter.get_relation(database=None, schema='foo', identifier='bar') -%}
4245
{%- set to = adapter.get_relation(database=None, schema='foo', identifier='another') -%}
@@ -47,16 +50,17 @@ def test_adapter_relation(sushi_test_project: Project):
4750
)
4851

4952
assert (
50-
context.render(
53+
renderer(
5154
"{%- set relation = adapter.get_relation(database=None, schema='foo', identifier='bar') -%} {{ adapter.get_missing_columns(relation, relation) }}"
5255
)
5356
== "[]"
5457
)
5558

5659

57-
def test_adapter_dispatch(sushi_test_project: Project):
60+
def test_adapter_dispatch(sushi_test_project: Project, runtime_renderer: t.Callable):
5861
context = sushi_test_project.context
59-
assert context.render("{{ adapter.dispatch('current_engine', 'customers')() }}") == "duckdb"
62+
renderer = runtime_renderer(context)
63+
assert renderer("{{ adapter.dispatch('current_engine', 'customers')() }}") == "duckdb"
6064

6165
with pytest.raises(ConfigError, match=r"Macro 'current_engine'.*was not found."):
62-
context.render("{{ adapter.dispatch('current_engine')() }}")
66+
renderer("{{ adapter.dispatch('current_engine')() }}")

tests/dbt/test_transformation.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
column_types_to_sqlmesh,
2222
)
2323
from sqlmesh.dbt.context import DbtContext
24+
from sqlmesh.dbt.macros import dbt_utils_star
2425
from sqlmesh.dbt.model import Materialization, ModelConfig
2526
from sqlmesh.dbt.project import Project
2627
from sqlmesh.dbt.seed import SeedConfig
@@ -241,27 +242,30 @@ def test_this(assert_exp_eq, sushi_test_project: Project):
241242
)
242243

243244

244-
def test_statement(sushi_test_project: Project):
245+
def test_statement(sushi_test_project: Project, runtime_renderer: t.Callable):
245246
context = sushi_test_project.context
246-
assert context.render(
247-
"{% set test_var = 'SELECT 1' %}{% call statement('something', fetch_result=True) %} {{ test_var }} {% endcall %}{{ load_result('something').table }}"
247+
renderer = runtime_renderer(context)
248+
assert renderer(
249+
"{% set test_var = 'SELECT 1' %}{% call statement('something', fetch_result=True) %} {{ test_var }} {% endcall %}{{ load_result('something').table }}",
248250
) == str(agate.Table([[1]], column_names=["1"], column_types=[agate.Number()]))
249251

250252

251-
def test_run_query(sushi_test_project: Project):
253+
def test_run_query(sushi_test_project: Project, runtime_renderer: t.Callable):
252254
context = sushi_test_project.context
253-
assert context.render("{{ run_query('SELECT 1 UNION ALL SELECT 2') }}") == str(
255+
renderer = runtime_renderer(context)
256+
assert renderer("{{ run_query('SELECT 1 UNION ALL SELECT 2') }}") == str(
254257
agate.Table([[1], [2]], column_names=["1"], column_types=[agate.Number()])
255258
)
256259

257260

258-
def test_logging(capsys, sushi_test_project: Project):
261+
def test_logging(capsys, sushi_test_project: Project, runtime_renderer: t.Callable):
259262
context = sushi_test_project.context
263+
renderer = runtime_renderer(context)
260264

261-
assert context.render('{{ log("foo") }}') == ""
265+
assert renderer('{{ log("foo") }}') == ""
262266
assert "foo" in capsys.readouterr().out
263267

264-
assert context.render('{{ print("bar") }}') == ""
268+
assert renderer('{{ print("bar") }}') == ""
265269
assert "bar" in capsys.readouterr().out
266270

267271

@@ -394,3 +398,23 @@ def test_dbt_version(sushi_test_project: Project):
394398
context = sushi_test_project.context
395399

396400
assert context.render("{{ dbt_version }}").startswith("1.")
401+
402+
403+
def test_dbt_utils_star_macro(sushi_test_project: Project):
404+
context = sushi_test_project.context
405+
context.jinja_macros.add_macros({"star": dbt_utils_star().info}, "dbt_utils")
406+
context._jinja_environment = None
407+
408+
assert context.render("{{ dbt_utils.star(from='foo') }}") == "foo.*"
409+
assert (
410+
context.render("{{ dbt_utils.star(from='foo', except=['bar']) }}")
411+
== """foo.* EXCEPT ("bar")"""
412+
)
413+
assert (
414+
context.render("{{ dbt_utils.star(from='foo', except=['bar', 'baz']) }}")
415+
== """foo.* EXCEPT ("bar", "baz")"""
416+
)
417+
with pytest.raises(CompilationError):
418+
context.render("{{ dbt_utils.star(from='foo', prefix='pre') }}")
419+
with pytest.raises(CompilationError):
420+
context.render("{{ dbt_utils.star(from='foo', suffix='suf') }}")

0 commit comments

Comments
 (0)