fix WHERE behaviour

main
Michael 2025-01-30 17:03:28 +01:00
parent 4ac0ac96eb
commit e98915f3f1
1 changed files with 26 additions and 17 deletions

View File

@ -37,6 +37,13 @@ class BaseType(ABC):
def TIME(self):
return self.time
def WHERE(self,other):
if isinstance(other, ListType):
return ListType([self for b_item in other.value if b_item.value==True])
if isinstance(other, BoolType) and other:
return self
return NullType()
class NumType(BaseType):
def __init__(self, value: (int, float)):
@ -77,31 +84,31 @@ class NumType(BaseType):
def __gt__(self, other):
if isinstance(other, NumType):
return BoolType(self.value > other.value)
return NullType
return NullType()
def __lt__(self, other):
if isinstance(other, NumType):
return BoolType(self.value < other.value)
return NullType
return NullType()
def __ge__(self, other):
if isinstance(other, NumType):
return BoolType(self.value >= other.value)
if isinstance(other, ListType):
return ListType([self >= item for item in other.value])
return NullType
return NullType()
def __le__(self, other):
if isinstance(other, NumType):
return BoolType(self.value <= other.value)
if isinstance(other, ListType):
return ListType([self <= item for item in other.value])
return NullType
return NullType()
def __pow__(self, other):
if isinstance(other, NumType):
return NumType(self.value ** other.value)
return NullType
return NullType()
def __neg__(self):
return NumType(-self.value)
@ -156,7 +163,7 @@ class BoolType(BaseType):
def __and__(self,other):
if isinstance(other, BoolType):
return BoolType(self.value and other.value)
return NullType
return NullType()
def NOT(self):
return ~self
@ -174,52 +181,54 @@ class ListType(BaseType):
def __gt__(self, other):
if isinstance(other, NumType):
return ListType([item > other for item in self.value])
return NullType
return NullType()
def __lt__(self, other):
if isinstance(other, NumType):
return ListType([item < other for item in self.value])
return NullType
return NullType()
def __ge__(self, other):
if isinstance(other, NumType):
return ListType([item >= other for item in self.value])
if isinstance(other, ListType) and len(self.value) == len(other.value):
return ListType([a_item >= b_item for a_item, b_item in zip(self.value, other.value)])
return NullType
return NullType()
def __le__(self, other):
if isinstance(other, NumType):
return ListType([item <= other for item in self.value])
if isinstance(other, ListType) and len(self.value) == len(other.value):
return ListType([a_item <= b_item for a_item, b_item in zip(self.value, other.value)])
return NullType
return NullType()
def __and__(self,other):
if isinstance(other, BoolType) or isinstance(other, NumType):
return ListType([item and other for item in self.value])
if isinstance(other, ListType) and len(self.value) == len(other.value):
return ListType([a_item and b_item for a_item, b_item in zip(self.value, other.value)])
return NullType
return NullType()
def __or__(self,other):
if isinstance(other, BoolType) or isinstance(other, NumType):
return ListType([item or other for item in self.value])
if isinstance(other, ListType) and len(self.value) == len(other.value):
return ListType([a_item or b_item for a_item, b_item in zip(self.value, other.value)])
return NullType
return NullType()
def __truediv__(self,other):
if isinstance(other, NumType):
return ListType([item / other for item in self.value])
if isinstance(other, ListType) and len(self.value) == len(other.value):
return ListType([a_item / b_item for a_item, b_item in zip(self.value, other.value)])
return NullType
return NullType()
def WHERE(self,other):
if isinstance(other, ListType) and len(self.value) == len(other.value):
return ListType([a_item for a_item, b_item in zip(self.value, other.value) if b_item])
return NullType
return ListType([a_item for a_item, b_item in zip(self.value, other.value) if b_item.value==True])
if isinstance(other, BoolType) and other:
return self
return NullType()
def IS(self,a,invert=False):
if a is NumType:
@ -262,12 +271,12 @@ class DateType(BaseType):
def __gt__(self, other):
if isinstance(other, DateType):
return BoolType(self.value > other.value)
return NullType
return NullType()
def __lt__(self, other):
if isinstance(other, DateType):
return BoolType(self.value < other.value)
return NullType
return NullType()
def __str__(self):
return self.value.strftime('%Y-%m-%dT%H:%M:%S')