fix WHERE behaviour
parent
4ac0ac96eb
commit
e98915f3f1
|
@ -37,6 +37,13 @@ class BaseType(ABC):
|
|||
def TIME(self):
|
||||
return self.time
|
||||
|
||||
def WHERE(self,other):
|
||||
if isinstance(other, ListType):
|
||||
return ListType([self for b_item in other.value if b_item.value==True])
|
||||
if isinstance(other, BoolType) and other:
|
||||
return self
|
||||
return NullType()
|
||||
|
||||
|
||||
class NumType(BaseType):
|
||||
def __init__(self, value: (int, float)):
|
||||
|
@ -77,31 +84,31 @@ class NumType(BaseType):
|
|||
def __gt__(self, other):
|
||||
if isinstance(other, NumType):
|
||||
return BoolType(self.value > other.value)
|
||||
return NullType
|
||||
return NullType()
|
||||
|
||||
def __lt__(self, other):
|
||||
if isinstance(other, NumType):
|
||||
return BoolType(self.value < other.value)
|
||||
return NullType
|
||||
return NullType()
|
||||
|
||||
def __ge__(self, other):
|
||||
if isinstance(other, NumType):
|
||||
return BoolType(self.value >= other.value)
|
||||
if isinstance(other, ListType):
|
||||
return ListType([self >= item for item in other.value])
|
||||
return NullType
|
||||
return NullType()
|
||||
|
||||
def __le__(self, other):
|
||||
if isinstance(other, NumType):
|
||||
return BoolType(self.value <= other.value)
|
||||
if isinstance(other, ListType):
|
||||
return ListType([self <= item for item in other.value])
|
||||
return NullType
|
||||
return NullType()
|
||||
|
||||
def __pow__(self, other):
|
||||
if isinstance(other, NumType):
|
||||
return NumType(self.value ** other.value)
|
||||
return NullType
|
||||
return NullType()
|
||||
|
||||
def __neg__(self):
|
||||
return NumType(-self.value)
|
||||
|
@ -156,7 +163,7 @@ class BoolType(BaseType):
|
|||
def __and__(self,other):
|
||||
if isinstance(other, BoolType):
|
||||
return BoolType(self.value and other.value)
|
||||
return NullType
|
||||
return NullType()
|
||||
|
||||
def NOT(self):
|
||||
return ~self
|
||||
|
@ -174,52 +181,54 @@ class ListType(BaseType):
|
|||
def __gt__(self, other):
|
||||
if isinstance(other, NumType):
|
||||
return ListType([item > other for item in self.value])
|
||||
return NullType
|
||||
return NullType()
|
||||
|
||||
def __lt__(self, other):
|
||||
if isinstance(other, NumType):
|
||||
return ListType([item < other for item in self.value])
|
||||
return NullType
|
||||
return NullType()
|
||||
|
||||
def __ge__(self, other):
|
||||
if isinstance(other, NumType):
|
||||
return ListType([item >= other for item in self.value])
|
||||
if isinstance(other, ListType) and len(self.value) == len(other.value):
|
||||
return ListType([a_item >= b_item for a_item, b_item in zip(self.value, other.value)])
|
||||
return NullType
|
||||
return NullType()
|
||||
|
||||
def __le__(self, other):
|
||||
if isinstance(other, NumType):
|
||||
return ListType([item <= other for item in self.value])
|
||||
if isinstance(other, ListType) and len(self.value) == len(other.value):
|
||||
return ListType([a_item <= b_item for a_item, b_item in zip(self.value, other.value)])
|
||||
return NullType
|
||||
return NullType()
|
||||
|
||||
def __and__(self,other):
|
||||
if isinstance(other, BoolType) or isinstance(other, NumType):
|
||||
return ListType([item and other for item in self.value])
|
||||
if isinstance(other, ListType) and len(self.value) == len(other.value):
|
||||
return ListType([a_item and b_item for a_item, b_item in zip(self.value, other.value)])
|
||||
return NullType
|
||||
return NullType()
|
||||
|
||||
def __or__(self,other):
|
||||
if isinstance(other, BoolType) or isinstance(other, NumType):
|
||||
return ListType([item or other for item in self.value])
|
||||
if isinstance(other, ListType) and len(self.value) == len(other.value):
|
||||
return ListType([a_item or b_item for a_item, b_item in zip(self.value, other.value)])
|
||||
return NullType
|
||||
return NullType()
|
||||
|
||||
def __truediv__(self,other):
|
||||
if isinstance(other, NumType):
|
||||
return ListType([item / other for item in self.value])
|
||||
if isinstance(other, ListType) and len(self.value) == len(other.value):
|
||||
return ListType([a_item / b_item for a_item, b_item in zip(self.value, other.value)])
|
||||
return NullType
|
||||
return NullType()
|
||||
|
||||
def WHERE(self,other):
|
||||
if isinstance(other, ListType) and len(self.value) == len(other.value):
|
||||
return ListType([a_item for a_item, b_item in zip(self.value, other.value) if b_item])
|
||||
return NullType
|
||||
return ListType([a_item for a_item, b_item in zip(self.value, other.value) if b_item.value==True])
|
||||
if isinstance(other, BoolType) and other:
|
||||
return self
|
||||
return NullType()
|
||||
|
||||
def IS(self,a,invert=False):
|
||||
if a is NumType:
|
||||
|
@ -262,12 +271,12 @@ class DateType(BaseType):
|
|||
def __gt__(self, other):
|
||||
if isinstance(other, DateType):
|
||||
return BoolType(self.value > other.value)
|
||||
return NullType
|
||||
return NullType()
|
||||
|
||||
def __lt__(self, other):
|
||||
if isinstance(other, DateType):
|
||||
return BoolType(self.value < other.value)
|
||||
return NullType
|
||||
return NullType()
|
||||
|
||||
def __str__(self):
|
||||
return self.value.strftime('%Y-%m-%dT%H:%M:%S')
|
||||
|
|
Loading…
Reference in New Issue