Fixed #2443 -- Added DurationField.

A field for storing periods of time - modeled in Python by timedelta. It
is stored in the native interval data type on PostgreSQL and as a bigint
of microseconds on other backends.

Also includes significant changes to the internals of time related maths
in expressions, including the removal of DateModifierNode.

Thanks to Tim and Josh in particular for reviews.
This commit is contained in:
Marc Tamlyn 2014-07-24 13:57:24 +01:00
parent a3d96bee36
commit 57554442fe
26 changed files with 524 additions and 138 deletions

View file

@ -34,12 +34,12 @@ class CombinableMixin(object):
BITOR = '|'
def _combine(self, other, connector, reversed, node=None):
if isinstance(other, datetime.timedelta):
return DateModifierNode(self, connector, other)
if not hasattr(other, 'resolve_expression'):
# everything must be resolvable to an expression
other = Value(other)
if isinstance(other, datetime.timedelta):
other = DurationValue(other, output_field=fields.DurationField())
else:
other = Value(other)
if reversed:
return Expression(other, connector, self)
@ -333,6 +333,18 @@ class Expression(ExpressionNode):
self.lhs, self.rhs = exprs
def as_sql(self, compiler, connection):
try:
lhs_output = self.lhs.output_field
except FieldError:
lhs_output = None
try:
rhs_output = self.rhs.output_field
except FieldError:
rhs_output = None
if (not connection.features.has_native_duration_field and
((lhs_output and lhs_output.get_internal_type() == 'DurationField')
or (rhs_output and rhs_output.get_internal_type() == 'DurationField'))):
return DurationExpression(self.lhs, self.connector, self.rhs).as_sql(compiler, connection)
expressions = []
expression_params = []
sql, params = compiler.compile(self.lhs)
@ -354,45 +366,31 @@ class Expression(ExpressionNode):
return c
class DateModifierNode(Expression):
"""
Node that implements the following syntax:
filter(end_date__gt=F('start_date') + datetime.timedelta(days=3, seconds=200))
which translates into:
POSTGRES:
WHERE end_date > (start_date + INTERVAL '3 days 200 seconds')
MYSQL:
WHERE end_date > (start_date + INTERVAL '3 0:0:200:0' DAY_MICROSECOND)
ORACLE:
WHERE end_date > (start_date + INTERVAL '3 00:03:20.000000' DAY(1) TO SECOND(6))
SQLITE:
WHERE end_date > django_format_dtdelta(start_date, "+" "3", "200", "0")
(A custom function is used in order to preserve six digits of fractional
second information on sqlite, and to format both date and datetime values.)
Note that microsecond comparisons are not well supported with MySQL, since
MySQL does not store microsecond information.
Only adding and subtracting timedeltas is supported, attempts to use other
operations raise a TypeError.
"""
def __init__(self, lhs, connector, rhs):
if not isinstance(rhs, datetime.timedelta):
raise TypeError('rhs must be a timedelta.')
if connector not in (self.ADD, self.SUB):
raise TypeError('Connector must be + or -, not %s' % connector)
super(DateModifierNode, self).__init__(lhs, connector, Value(rhs))
class DurationExpression(Expression):
def compile(self, side, compiler, connection):
if not isinstance(side, DurationValue):
try:
output = side.output_field
except FieldError:
pass
if output.get_internal_type() == 'DurationField':
sql, params = compiler.compile(side)
return connection.ops.format_for_duration_arithmetic(sql), params
return compiler.compile(side)
def as_sql(self, compiler, connection):
timedelta = self.rhs.value
sql, params = compiler.compile(self.lhs)
if (timedelta.days == timedelta.seconds == timedelta.microseconds == 0):
return sql, params
return connection.ops.date_interval_sql(sql, self.connector, timedelta), params
expressions = []
expression_params = []
sql, params = self.compile(self.lhs, compiler, connection)
expressions.append(sql)
expression_params.extend(params)
sql, params = self.compile(self.rhs, compiler, connection)
expressions.append(sql)
expression_params.extend(params)
# order of precedence
expression_wrapper = '(%s)'
sql = connection.ops.combine_duration_expression(self.connector, expressions)
return expression_wrapper % sql, expression_params
class F(CombinableMixin):
@ -488,6 +486,13 @@ class Value(ExpressionNode):
return '%s', [self.value]
class DurationValue(Value):
def as_sql(self, compiler, connection):
if connection.features.has_native_duration_field:
return super(DurationValue, self).as_sql(compiler, connection)
return connection.ops.date_interval_sql(self.value)
class Col(ExpressionNode):
def __init__(self, alias, target, source=None):
if source is None: