Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Closed
2 changes: 1 addition & 1 deletion sqlalchemy_bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from .version import __version__

from .base import BigQueryDialect, dialect
from .base import BigQueryDialect, dialect, TimePartitioning
from ._types import (
ARRAY,
BIGNUMERIC,
Expand Down
85 changes: 83 additions & 2 deletions sqlalchemy_bigquery/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import random
import operator
import uuid
import typing

from google import auth
import google.api_core.exceptions
Expand All @@ -36,6 +37,7 @@
import sqlalchemy.sql.sqltypes
import sqlalchemy.sql.type_api
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.exc import NoSuchColumnError
from sqlalchemy import util
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.compiler import (
Expand Down Expand Up @@ -64,6 +66,61 @@
TABLE_VALUED_ALIAS_ALIASES = "bigquery_table_valued_alias_aliases"


class TimePartitioningType(object):
DAY = "DAY"
HOUR = "HOUR"
MONTH = "MONTH"
YEAR = "YEAR"


class TimePartitioning(sqlalchemy.sql.sqltypes.TypeEngine):
__visit_name__ = "TimePartitioning"

TimePartitioningType = TimePartitioningType

def __init__(self, type_: str, field: str, expiration_ms: int = 31536000000, require_partition_filter: bool = True):
self.type_ = type_
self.field = field
self.expiration_ms = expiration_ms
self.require_partition_filter = require_partition_filter

def __bool__(self) -> bool:
return self.field is not None

def __str__(self) -> str:
return f"DATE_TRUNC(`{self.field}`, {self.type_})"

def __iter__(self) -> typing.Iterable[str]:
yield "partition_expiration_days = {}".format(self.expiration)
if self.require:
yield "require_partition_filter = true"

def __call__(self, preparer: IdentifierPreparer) -> str:
return f"DATE_TRUNC({preparer.quote(self.field)}, {self.type_})"

@property
def expiration(self) -> int:
return self.expiration_ms//(1000*60*60*24)

@property
def require(self) -> bool:
return self.require_partition_filter

@property
def oprtions(self) -> list[str]:
return list(self)

@classmethod
def frombigquery_time_partitioning(cls, partition) -> 'TimePartitioning':
if partition:
return cls(
partition.type_,
partition.field,
partition.expiration_ms,
partition.require_partition_filter)
return cls(TimePartitioningType.DAY, None)


def assert_(cond, message="Assertion failed"): # pragma: NO COVER
if not cond:
raise AssertionError(message)
Expand Down Expand Up @@ -649,6 +706,30 @@ def post_create_table(self, table):
bq_opts = table.dialect_options["bigquery"]
opts = []

text = ""
if "time_partitioning" in bq_opts and bq_opts["time_partitioning"]:
partition = bq_opts["time_partitioning"]
if partition.field not in table.c:
raise NoSuchColumnError(partition.field)
text += "\nPARTITION BY {}".format(partition)
opts.extend(partition)

if "clustering_fields" in bq_opts and bq_opts["clustering_fields"]:
cluster = bq_opts["clustering_fields"]
for n in cluster:
if n not in table.c:
raise NoSuchColumnError(n)
text += "\nCLUSTER BY ({})".format(
",".join(
[
self.preparer.format_column(
table.c[n], use_table=False, use_schema=False
)
for n in cluster
]
)
)

if ("description" in bq_opts) or table.comment:
description = process_string_literal(
bq_opts.get("description", table.comment)
Expand All @@ -663,9 +744,9 @@ def post_create_table(self, table):
)

if opts:
return "\nOPTIONS({})".format(", ".join(opts))
return text + "\nOPTIONS({})".format(", ".join(opts))

return ""
return text

def visit_set_table_comment(self, create):
table_name = self.preparer.format_table(create.element)
Expand Down