Now we can solve simple equations (linear and quadratic).

portnov [2008-12-14 06:39:08]
Now we can solve simple equations (linear and quadratic).
Filename
data.py
syntax.py
test.g
test.py
diff --git a/data.py b/data.py
index a58068b..43b198e 100644
--- a/data.py
+++ b/data.py
@@ -1,4 +1,5 @@

+from math import sqrt
 from copy import deepcopy as copy

 COND_PARAL = 1
@@ -30,6 +31,19 @@ def is_plus(x):
 def is_mul(x):
     return isinstance(x,Expr) and (x.op == '*')

+def sqrt_(x):
+    if is_number(x):
+        if x < 0:
+            raise ValueError, "sqrt(x<0)!"
+        return sqrt(x)
+    elif isinstance(x,Union):
+        n = Union()
+        for e in x:
+            n.add(sqrt_(e))
+        return n
+    else:
+        return Expr('sqrt',x,None)
+
 class Unknown(object):
     def __init__(self,name='',add=True):
         self.name = name
@@ -55,6 +69,11 @@ class Const(float):
     def __add__(self,e):
         if is_number(e):
             return float.__add__(self,e)
+        if isinstance(e,Union):
+            n = Union()
+            for x in e:
+                n.add(self+x)
+            return n
         return Expr('+',self,e)

     def __sub__(self,e):
@@ -65,6 +84,11 @@ class Const(float):
     def __mul__(self,e):
         if is_number(e):
             return float.__mul__(self,e)
+        if isinstance(e,Union):
+            n = Union()
+            for x in e:
+                n.add(self*x)
+            return n
         return Expr('*',self,e)

 class Line(object):
@@ -228,10 +252,16 @@ class Monoid(object):
 #                     print u, "not in self.unknowns"
                     us[u] = e.unknowns[u]
             c = self.coef * e.coef
-            return Monoid(c, unknowns=us)
+            return Monoid(Const(c), unknowns=us)
         else:
             lst = [self*x for x in e]
             return Sum(lst)
+
+    def is_numeric(self):
+        if self.unknowns.keys():
+            return False
+        else:
+            return True

 class Expr(object):
     def __init__(self,op,x,y):
@@ -240,6 +270,8 @@ class Expr(object):
         self.y = y

     def __repr__(self):
+        if self.op == 'sqrt':
+            return "sqrt(%s)" % self.x
         if self.op=='*':
             if isinstance(self.x, Unknown):
                 return repr(self.y)+'*'+repr(self.x)
@@ -270,7 +302,25 @@ class Expr(object):
                 return Expr('+', self.x*e, self.y*e)
         return Expr('*',self,e)

+    def __div__(self,e):
+        if is_number(e):
+            if self.op == '*':
+                if is_number(self.x):
+                    return Expr('/',self.x*e,self.y)
+                elif is_number(self.y):
+                    return Expr('/',self.x,self.y*e)
+            elif self.op == '+':
+                return Expr('+', self.x/e, self.y/e)
+        return Expr('/',self,e)
+
 def to_sum(expr):
+    if isinstance(expr,Monoid):
+        return Sum([expr])
+    if isinstance(expr,Sum):
+        r = Sum([])
+        for m in expr:
+            r = r + to_sum(m)
+        return r
     if is_number(expr):
         return Sum([Monoid(expr)])
     if isinstance(expr,Unknown):
@@ -280,11 +330,72 @@ def to_sum(expr):
     if expr.op == '*':
         return to_sum(expr.x) * to_sum(expr.y)

+class Union(set):
+    def __repr__(self):
+        lst = list(self)
+        return repr(lst)
+
+    def __add__(self,e):
+        if isinstance(e,Union):
+            n = Union()
+            for x in self:
+                for y in e:
+                    n.add(Const(x)+Const(y))
+            return n
+        else:
+            n = Union()
+            for x in self:
+                n.add(Const(x)+e)
+            return n
+
+    def __neg__(self):
+        n = Union()
+        for x in self:
+            n.add(Const(-1)*x)
+        return n
+
+    def __sub__(self,e):
+        return self + Const(-1)*e
+
+    def __mul__(self,e):
+        if isinstance(e,Union):
+            n = Union()
+            for x in self:
+                for y in e:
+                    n.add(Const(x)*y)
+            return n
+        else:
+            n = Union()
+            for x in self:
+                n.add(x*e)
+            return n
+
+    def __pow__(self,e):
+        n = Union()
+        for x in self:
+            n.add(x**e)
+        return n
+
+    def __div__(self,e):
+        if isinstance(e,Union):
+            n = Union()
+            for x in self:
+                for y in e:
+                    n.add(Const(x)/y)
+            return n
+        else:
+            n = Union()
+            for x in self:
+                n.add(x/e)
+            return n
+
+
+
 def collect_nums(sum):
     s = 0
     r = Sum([])
     for m in sum:
-        if not m.unknowns.keys():
+        if m.is_numeric():
             s += m.coef
         else:
             r = r + m
@@ -306,10 +417,10 @@ def collect_xy(sum):
 def collect_vars(sum):
     def key(m):
         return "".join([("%s%s" % (u,n)) for u,n in m.unknowns.iteritems()])
-    dd = {None: 0.0}
+    dd = {None: Const(0)}
     du = {}
     for m in sum:
-        if not m.unknowns.keys():
+        if m.is_numeric():
             dd[None] = dd[None] + m.coef
             continue
         k = key(m)
@@ -339,6 +450,109 @@ class Equation(object):
     def __repr__(self):
         return '%s = 0' % self.expr

+# @trace
+def max_power(s):
+    max = 0
+    for m in s:
+        pow = sum([m.unknowns[u] for u in m.unknowns if m.unknowns[u]])
+        if pow > max:
+            max = pow
+    return max
+
+# @trace
+def is_linear(eq):
+    return max_power(to_sum(eq.expr)) == 1
+
+# @trace
+def nr_unknowns(eq):
+    def nr_unknowns_sum(sum):
+        us = []
+        for m in sum:
+            for u in m.unknowns:
+                if not u in us:
+                    us.append(u)
+        return len(us)
+
+    return nr_unknowns_sum(to_sum(eq.expr))
+
+# @trace
+def is_simple(eq):
+    return is_linear(eq) and (nr_unknowns(eq) == 1)
+
+def is_quadratic_single(eq):
+    return (max_power(to_sum(eq.expr))==2) and (nr_unknowns(eq)==1)
+
+def linear_view(sum):
+    dt = {}
+    for m in sum:
+        if m.is_numeric():
+            dt[None] = m.coef
+        else:
+            ks = m.unknowns.keys()
+            dt[ks[0]] = m.coef
+    return list(sorted(dt.iteritems()))
+
+def single_quadratic_view(sum):
+    for m in sum:
+        if m.is_numeric():
+            c = m.coef
+        else:
+            u = m.unknowns.keys()[0]
+            pw = m.unknowns[u]
+            if pw == 1:
+                b = m.coef
+            elif pw == 2:
+                a = m.coef
+    return u,a,b,c
+
+def solve_simple(eq):
+    global unknowns
+    n,x = linear_view(to_sum(eq.expr))
+    _,b = n
+    u,a = x
+    value = -b/a
+    unknowns[u] = value
+    return u,value
+
+def solve_quadratic_single(eq):
+    global unknowns
+    u,a,b,c = single_quadratic_view(to_sum(eq.expr))
+    print "||", c, type(c)
+    D = b*b - Const(4)*(a*c)
+    if is_number(D) and D<0:
+        x = Union()
+    elif is_number(D) and D==0:
+        x = -b/(2*a)
+    else:
+        x1 = (-b - sqrt_(D))/(2*a)
+        x2 = (-b + sqrt_(D))/(2*a)
+        if isinstance(x1,Union) and isinstance(x2,Union):
+            x = Union(list(x1)+list(x2))
+        elif isinstance(x1,Union):
+            x = Union(list(x1)+[x2])
+        elif isinstance(x2,Union):
+            x = Union(list(x2)+[x1])
+        else:
+            x = Union([x1,x2])
+    unknowns[u] = x
+    return u,x
+
+def subst_knowns(eq):
+    global unknowns
+    print ">>",eq
+    r = Sum([])
+    for m in eq.expr:
+        us = UD({})
+        cc = Const(m.coef)
+        for u in m.unknowns:
+            if unknowns[u] is not None:
+                cc = cc * unknowns[u]**m.unknowns[u]
+            else:
+                us[u] = m.unknowns[u]
+#         print eq,us
+        r = r + Monoid(cc,unknowns=us)
+    return Equation(collect_vars(r))
+
 def equation_paral(l1,l2):
 #     print l1.A,'*',l2.B, '-', l2.A,'*',l1.B
     e = Expr('+',Expr('*',l1.A,l2.B), Expr('*',-1,Expr('*',l2.A,l1.B)))
@@ -379,7 +593,7 @@ class Condition(object):

     def __repr__(self):
         return "C:[%s %s %s]" % (self.x, self._kinds[self.kind], self.y)
-
+
     def equation(self):
         if self.kind == COND_PARAL:
             a,b = self.x.A, self.x.B
diff --git a/syntax.py b/syntax.py
index d1bd88a..8d113cd 100644
--- a/syntax.py
+++ b/syntax.py
@@ -58,6 +58,13 @@ if __name__ == '__main__':
             s = s.strip()
             ob = yacc.parse(s)
             if isinstance(ob,Condition):
-                print "(%s)\t%s" % (ob.number, ob.equation())
+                eq = subst_knowns(ob.equation())
+                print "(%s)\t%s" % (ob.number, eq)
+                if is_simple(eq):
+                    u,v = solve_simple(eq)
+                    print "\tSolved:\t%s = %s" % (u,v)
+                elif is_quadratic_single(eq):
+                    u,v = solve_quadratic_single(eq)
+                    print "\tSolved:\t%s = %s" % (u,v)


diff --git a/test.g b/test.g
new file mode 100644
index 0000000..0eb8552
--- /dev/null
+++ b/test.g
@@ -0,0 +1,9 @@
+A(1,1)
+B(3,3)
+C(4,1)
+AC || BD
+CD = 4
+E(6,2)
+EF || AC
+DF = 4
+
diff --git a/test.py b/test.py
index 98eeb94..6c77aae 100644
--- a/test.py
+++ b/test.py
@@ -2,6 +2,7 @@ from data import *

 K=Const
 U=Unknown
+Eq=Equation

 A=Point('A',K(1),K(3))
 B=Point('B',K(2),K(6))
@@ -10,12 +11,28 @@ D=point('D')
 l1 = line_by_two_points(A,B)
 l2 = line_by_two_points(C,D)

+print "Equations"
+ee = Eq(K(3)*U('x')-K(2))
+print ee
+print "Linear:", is_linear(ee), "Unknowns:", nr_unknowns(ee), "Simple:", is_simple(ee)
+print "X:", solve_simple(ee)
+ee2 = Eq(K(3)*U('x')*U('y')-K(2)*U('z'))
+print ee2
+print "Linear:", is_linear(ee2), "Unknowns:", nr_unknowns(ee2), "Simple:", is_simple(ee2)
+ee3 = Eq(K(3)*U('x')-K(2)*U('y'))
+print ee3
+print "Linear:", is_linear(ee3), "Unknowns:", nr_unknowns(ee3), "Simple:", is_simple(ee3)
+
+print
+
+print "Monoids"
 m1 = Monoid(3.0, unknowns=UD({U('x'):1, U('y'):2}))
 m2 = Monoid(2.0, unknowns=UD({U('x'):2}))
 print m1, '*', m2, '=', m1*m2
 print m1, '+', m2, '=', m1+m2
 print

+print "Lines"
 print to_sum(l1.equation().expr)
 print to_sum(l2.equation().expr)
 print collect_xy(to_sum(l2.equation().expr))
ViewGit