Fix a bug with sqrt(x<0).

portnov [2008-12-14 08:14:30]
Fix a bug with sqrt(x<0).
Filename
data.py
diff --git a/data.py b/data.py
index 43b198e..a8b704f 100644
--- a/data.py
+++ b/data.py
@@ -25,16 +25,18 @@ def trace(f):
 def is_number(x):
     return isinstance(x,float) or isinstance(x,int) or isinstance(x,Const)

-def is_plus(x):
-    return isinstance(x,Expr) and (x.op == '+')
+def is_nothing(x):
+    return isinstance(x,NothingClass)

-def is_mul(x):
-    return isinstance(x,Expr) and (x.op == '*')
+def obj(x):
+    if isinstance(x,float) or isinstance(x,int):
+        return Const(x)
+    return x

 def sqrt_(x):
     if is_number(x):
         if x < 0:
-            raise ValueError, "sqrt(x<0)!"
+            return Nothing
         return sqrt(x)
     elif isinstance(x,Union):
         n = Union()
@@ -44,6 +46,30 @@ def sqrt_(x):
     else:
         return Expr('sqrt',x,None)

+class NothingClass(object):
+    def __repr__(self):
+        return "Nothing"
+
+    def __add__(self,e):
+        return Nothing
+
+    def __mul__(self,e):
+        return Nothing
+
+    def __sub__(self,e):
+        return Nothing
+
+    def __sub__(self,e):
+        return Nothing
+
+    def __div__(self,e):
+        return Nothing
+
+    def __pow__(self,e):
+        return Nothing
+
+Nothing = NothingClass()
+
 class Unknown(object):
     def __init__(self,name='',add=True):
         self.name = name
@@ -69,6 +95,8 @@ class Const(float):
     def __add__(self,e):
         if is_number(e):
             return float.__add__(self,e)
+        if is_nothing(e):
+            return Nothing
         if isinstance(e,Union):
             n = Union()
             for x in e:
@@ -79,11 +107,15 @@ class Const(float):
     def __sub__(self,e):
         if is_number(e):
             return float.__sub__(self,e)
+        if is_nothing(e):
+            return Nothing
         return self + Expr('*',-1,e)

     def __mul__(self,e):
         if is_number(e):
             return float.__mul__(self,e)
+        if is_nothing(e):
+            return Nothing
         if isinstance(e,Union):
             n = Union()
             for x in e:
@@ -252,7 +284,7 @@ class Monoid(object):
 #                     print u, "not in self.unknowns"
                     us[u] = e.unknowns[u]
             c = self.coef * e.coef
-            return Monoid(Const(c), unknowns=us)
+            return Monoid(obj(c), unknowns=us)
         else:
             lst = [self*x for x in e]
             return Sum(lst)
@@ -340,12 +372,12 @@ class Union(set):
             n = Union()
             for x in self:
                 for y in e:
-                    n.add(Const(x)+Const(y))
+                    n.add(obj(x)+obj(y))
             return n
         else:
             n = Union()
             for x in self:
-                n.add(Const(x)+e)
+                n.add(obj(x)+e)
             return n

     def __neg__(self):
@@ -362,7 +394,7 @@ class Union(set):
             n = Union()
             for x in self:
                 for y in e:
-                    n.add(Const(x)*y)
+                    n.add(obj(x)*y)
             return n
         else:
             n = Union()
@@ -381,7 +413,7 @@ class Union(set):
             n = Union()
             for x in self:
                 for y in e:
-                    n.add(Const(x)/y)
+                    n.add(obj(x)/y)
             return n
         else:
             n = Union()
@@ -517,7 +549,6 @@ def solve_simple(eq):
 def solve_quadratic_single(eq):
     global unknowns
     u,a,b,c = single_quadratic_view(to_sum(eq.expr))
-    print "||", c, type(c)
     D = b*b - Const(4)*(a*c)
     if is_number(D) and D<0:
         x = Union()
@@ -539,11 +570,10 @@ def solve_quadratic_single(eq):

 def subst_knowns(eq):
     global unknowns
-    print ">>",eq
     r = Sum([])
     for m in eq.expr:
         us = UD({})
-        cc = Const(m.coef)
+        cc = obj(m.coef)
         for u in m.unknowns:
             if unknowns[u] is not None:
                 cc = cc * unknowns[u]**m.unknowns[u]
ViewGit