Add caching.

portnov [2008-05-19 08:22:23]
Add caching.
Remove bogus sM2 statistics.
Filename
stats.py
diff --git a/stats.py b/stats.py
index 6e140f4..107d494 100755
--- a/stats.py
+++ b/stats.py
@@ -10,6 +10,27 @@ import gtk
 TMPPLOT = "/tmp/stats.plot"
 TMPPNG = "/tmp/stats.png"

+cache = dict()
+
+def cached(func):
+  def wrapper(x):
+    i = id(x)
+    if func in cache:
+      if i in cache[func]:
+        return cache[func][i]
+      else:
+        t = func(x)
+        cache[func][i] = t
+        return t
+    else:
+      t = func(x)
+      cache[func] = {i: t}
+      return t
+  wrapper.__name__ = func.__name__
+  return wrapper
+
+# cached = lambda f: f
+
 class T(tuple):
   def __init__(self,*list):
     tuple.__init__(self,list)
@@ -35,6 +56,7 @@ class T(tuple):
   def __repr__(self):
     return " ".join(["%.3f" % d for d in self])

+@cached
 def summ(data):
   r = None
   for t in data:
@@ -44,6 +66,7 @@ def summ(data):
       r += t
   return r

+@cached
 def uniq(data):
   r = dict()
   for n in data:
@@ -53,6 +76,7 @@ def uniq(data):
       r[n] = 1
   return r

+@cached
 def probs(data):
   r = []
   for i in range(len(data[0])):
@@ -63,28 +87,19 @@ def probs(data):
     r.append(u)
   return r

+@cached
 def mean(data):
   return summ(data)/len(data)

+@cached
 def disp(data):
   m = mean(data)
   return summ([ (x-m)**2 for x in data]) / len(data)

+@cached
 def sigma(data):
   return T(map(sqrt,disp(data)))

-def stabmean(data):
-  p = probs(data)
-  r = []
-  for ps in p:
-    p2 = dict()
-    for n in ps:
-      p2[n] = ps[n]**2
-    s = sum(p2.values())
-    m = summ([x*p2[x] for x in ps])/s
-    r.append(m)
-  return T(r)
-
 def write_script(data,ms,ss):
   colors = ["red","blue","green","orange","darkblue","redorange","teal","oceanblue", "yelloworange", "purple"]
   def write_data(f,data):
@@ -172,9 +187,8 @@ class GUI(object):
     self.vbox.pack_start(self.text)
     self.window.add(self.vbox)

-  def markup(self,M,D,s,sM):
+  def markup(self,M,D,s):
     global VARS
-    print VARS

     r = ""
     for i in range(len(M)):
@@ -182,15 +196,13 @@ class GUI(object):
         var = VARS[i]
       except IndexError:
         var = "X%d" % (i+1)
-      r += """<b>M%s:</b> %.3f  <b>D%s:</b> %.3f
-<b>sM%s:</b> %.3f  <b>σ%s:</b> %.3f
-
-""" % (var,M[i],var,D[i],var,sM[i],var,s[i])
+      r += """<b>M%s:</b> %.3f  <b>D%s:</b> %.3f   <b>σ%s:</b> %.3f
+""" % (var,M[i],var,D[i],var,s[i])
     return r

-  def show(self,M,D,s,sM):
+  def show(self,M,D,s):
     self.img.set_from_file(TMPPNG)
-    self.text.set_markup(self.markup(M,D,s,sM))
+    self.text.set_markup(self.markup(M,D,s))
     self.window.show_all()
     gtk.main()

@@ -219,15 +231,14 @@ if __name__ == "__main__":
   M = mean(data)
   D = disp(data)
   s = sigma(data)
-  sM = stabmean(data)

   write_script(data,M,s)
   os.system("ploticus -maxproclines 20000 -png %s -o %s" % (TMPPLOT,TMPPNG))

   gui = GUI()
-  gui.show(M,D,s,sM)
+  gui.show(M,D,s)

   os.remove(TMPPLOT)
   os.remove(TMPPNG)

-  print "MX: %s, DX: %s, σX: %s, sMX: %s" % (M,D,s,sM)
+  print "MX: %s\nDX: %s\nσX: %s\n" % (M,D,s)
ViewGit