Attempt to generate SQL by models definitions.

portnov [2009-06-19 19:20:30]
Attempt to generate SQL by models definitions.
Filename
Blog/Blog
Blog/Config.hs
Blog/Models.hs
Blog/blog.db
Framework/API.hs
Framework/Models.hs
Framework/SQL.hs
Framework/Storage.hs
Framework/Urls.hs
diff --git a/Blog/Blog b/Blog/Blog
index 985960b..7a08d03 100755
Binary files a/Blog/Blog and b/Blog/Blog differ
diff --git a/Blog/Config.hs b/Blog/Config.hs
index 8c8b11d..be94fc0 100644
--- a/Blog/Config.hs
+++ b/Blog/Config.hs
@@ -5,8 +5,10 @@ import Framework.Types

 params = HP { docdir = "",
               hLog = stdout,
-              dbDriver = "sqlite3",
-              dbPath = "blog.db",
+--               dbDriver = "sqlite3",
+              dbDriver = "psql",
+--               dbPath = "blog.db",
+              dbPath = "host=rtfm-server password=31415",
               cacheDriver = "filesystem",
               cachePath = "tmp/",
               sessionsDriver = "files",
diff --git a/Blog/Models.hs b/Blog/Models.hs
index c595d15..0bb8b82 100644
--- a/Blog/Models.hs
+++ b/Blog/Models.hs
@@ -12,11 +12,13 @@ import Framework.Models

 postModel = Model {
     mName = "post",
+    mTable = "posts",
     mFields = ["id" ::: IntegerColumn,
                "dt" ::: CurrentDateColumn,
                "title" ::: StringColumn,
                "body"  ::: StringColumn],
-    mCached = ["ncomments" ::: IntegerColumn]
+    mCached = ["ncomments" ::: IntegerColumn],
+    mChildren = [(commentModel,"id","pid")]
     }

 postid = show.(transformInt 1 id)
@@ -38,12 +40,14 @@ addNComments post n = setCached postModel "ncomments" IntegerColumn n

 commentModel = Model {
     mName = "comment",
+    mTable = "comments",
     mFields = ["id" ::: IntegerColumn,
                "pid" ::: IntegerColumn,
                "dt" ::: CurrentDateColumn,
                "author" ::: StringColumn,
                "body" ::: StringColumn ],
-    mCached = []
+    mCached = [],
+    mChildren = []
     }

 commentId = show.(transformInt 1 id)
diff --git a/Blog/blog.db b/Blog/blog.db
index 462ccc7..3a9aa16 100644
Binary files a/Blog/blog.db and b/Blog/blog.db differ
diff --git a/Framework/API.hs b/Framework/API.hs
index ca73c6d..5101ec7 100644
--- a/Framework/API.hs
+++ b/Framework/API.hs
@@ -62,7 +62,7 @@ commit ac = Storage.commit (dbconnection ac)
 -- * Storage/SQL API

 queryListSQL :: ActionConfig -> SQL.Query -> [HDBC.SqlValue] -> IO [[HDBC.SqlValue]]
-queryListSQL ac q params = Storage.query (dbconnection ac) (SQL.sql q) params
+queryListSQL ac q params = Storage.query (dbconnection ac) (trace (SQL.sql q) (SQL.sql q)) params

 queryListSQL' :: ActionConfig -> SQL.Query -> [HDBC.SqlValue] -> IO [[HDBC.SqlValue]]
 queryListSQL' ac q params = Storage.query' (dbconnection ac) (SQL.sql q) params
diff --git a/Framework/Models.hs b/Framework/Models.hs
index 769f942..0171143 100644
--- a/Framework/Models.hs
+++ b/Framework/Models.hs
@@ -21,11 +21,17 @@ defaultValue CurrentDateColumn = SqlString "current_timestamp"

 data Model = Model {
     mName :: String,
+    mTable :: String,
     mFields :: [ModelField],
-    mCached :: [ModelField]
+    mCached :: [ModelField],
+    mChildren :: [(Model,String,String)]
     }
     deriving (Show)

+cModel (m,_,_) = m
+cParent (_,f,_) = f
+cChild (_,_,c) = c
+
 data ModelField = String ::: ColumnType
                 | FilledField String ColumnType SqlValue
     deriving (Show)
diff --git a/Framework/SQL.hs b/Framework/SQL.hs
index b0fb257..62f1c0e 100644
--- a/Framework/SQL.hs
+++ b/Framework/SQL.hs
@@ -9,8 +9,8 @@ module Framework.SQL
      sgroup, order,
      restrict,
      limit,
-     insertQ,
-     updateQ,
+     countChildren,
+--      insertQ, updateQ,
      aggregate, count
     ) where

@@ -18,9 +18,14 @@ import Data.List
 import Database.HDBC
 import qualified Data.Convertible.Base as CD

+import Framework.Models
+
+data Tables = TableList [SQLTable] | TableJoin [SQLTable]
+    deriving (Eq,Show)
+
 data Query = Query {
     qFields :: [SQLField],
-    qTables :: [SQLTable],
+    qTables :: Tables,
     qWhere  :: SQLCondition,
     qOrder  :: [SQLOrder],
     qGroup  :: [String],
@@ -37,20 +42,13 @@ data Query = Query {
     deriving (Eq,Show)

 data SQLField = QField String
-              | QFn String String String
-              | AsF String String
+              | QFn String String
     deriving (Eq,Show)

 fieldname (QField n) = n
-fieldname (QFn _ _ n) = n
-fieldname (AsF _ n) = n
+fieldname (QFn _ n) = n

-data SQLTable = QTable String
-              | AsT String String
-    deriving (Eq,Show)
-
-tablename (QTable n) = n
-tablename (AsT _ n) = n
+type SQLTable = String

 data SQLCondition =
       NoCondition
@@ -62,7 +60,6 @@ data SQLCondition =
     | SQLCondition :|: SQLCondition
     deriving (Eq,Show)

--- TODO: support ... WHERE x.field...
 type Selector = String

 data SQLOrder = Asceding String | Desceding String
@@ -73,40 +70,39 @@ class SQLFragment s where

 instance SQLFragment SQLCondition where
     sqlFragment NoCondition = ""
-    sqlFragment (x :==: y) = sqlFPair "=" x y
-    sqlFragment (x :/=: y) = sqlFPair "!=" x y
-    sqlFragment (x :>: y) = sqlFPair ">" x y
-    sqlFragment (x :<: y) = sqlFPair "<" x y
+    sqlFragment (x :==: y) = sqlLift "=" x y
+    sqlFragment (x :/=: y) = sqlLift "!=" x y
+    sqlFragment (x :>: y) = sqlLift ">" x y
+    sqlFragment (x :<: y) = sqlLift "<" x y
     sqlFragment (x :&: y) = "("++(sqlFPair " AND " x y)++")"
     sqlFragment (x :|: y) = "("++(sqlFPair " OR " x y)++")"

 sqlFPair :: (SQLFragment f) => String -> f -> f -> String
 sqlFPair op x y = (sqlFragment x)++op++(sqlFragment y)

+sqlLift op x y = x++op++y
+
 instance SQLFragment SQLField where
     sqlFragment (QField n) = n
-    sqlFragment (QFn a fn f) = fn++"("++f++") AS "++a
-    sqlFragment (AsF a n) = n++" AS "++a
-
-instance SQLFragment SQLTable where
-    sqlFragment (QTable n) = n
-    sqlFragment (AsT a n) = n++" "++a
-
-instance SQLFragment Selector where
-    sqlFragment s = s
+    sqlFragment (QFn fn f) = fn++"("++f++")"

 instance SQLFragment SQLOrder where
     sqlFragment (Asceding o) = o++" ASC"
     sqlFragment (Desceding o) = o++" DESC"

-
 sql :: Query -> String
-sql (Query fields tables whre order group ls) = "SELECT "++(sqlList fields)++" FROM "++(sqlList tables)++other
-    where other = wpart++opart++gpart++lpart
-          wpart = if whre==NoCondition then "" else " WHERE "++(sqlFragment whre)
-          opart = if null order then "" else " ORDER BY "++(sqlList order)
-          gpart = if null group then "" else " GROUP BY "++(commas group)
-          lpart | Just (x,y) <- ls = " LIMIT "++(show x)++", "++(show y)
+sql (Query fields tables whre order group ls) = "SELECT "++(sqlList fields)++" FROM "++tlist++other
+    where other = wpart++gpart++opart++lpart
+          tlist | TableList ts <- tables = commas ts
+                | TableJoin ts <- tables = sqlJoin ts
+          wpart | whre == NoCondition = ""
+                | TableList _ <- tables = " WHERE "++(sqlFragment whre)
+                | TableJoin _ <- tables = " ON "++(sqlFragment whre)
+          opart | null order = ""
+                | otherwise  = " ORDER BY "++(sqlList order)
+          gpart | null group = ""
+                | otherwise  = " GROUP BY "++(commas group)
+          lpart | Just (x,y) <- ls = " OFFSET "++(show x)++" LIMIT "++(show y)
                 | otherwise = ""
 sql (InsertQuery table fields values) = "INSERT INTO "++table++" ("++(commas fields)++") VALUES ("++(commas values)++")"
 sql (UpdateQuery table fields values whre) = "UPDATE "++table++" SET "++eqs++wpart
@@ -115,21 +111,77 @@ sql (UpdateQuery table fields values whre) = "UPDATE "++table++" SET "++eqs++wpa

 commas = concat . intersperse ", "
 sqlList = commas.map sqlFragment
+sqlJoin = concat . (intersperse " LEFT JOIN ")

-aggregate q fn = q {qFields=(map (liftF fn) (qFields q))}
+aggregate q fn = q {qFields=(onlyLast (liftF fn) (qFields q))}

-liftF fn (QField name) = QFn (fname++fn) fn name
-    where fname | name=="*" = "all"
-                | otherwise = name
-liftF fn (AsF a name) = QFn a fn name
-liftF fn (QFn a _ name) = QFn a fn name
+onlyLast f lst = (init lst)++[(f $ last lst)]
+
+liftF fn (QField name) = QFn fn name
+liftF fn (QFn _ name) = QFn fn name

 count = flip aggregate "count"

 allFields = [QField "*"]

-table t = Query allFields [QTable t] NoCondition [] [] Nothing
-tables ts = Query allFields (map QTable ts) NoCondition [] [] Nothing
+tableR t = Query allFields (TableList [t]) NoCondition [] [] Nothing
+tablesR ts = Query allFields (TableList ts) NoCondition [] [] Nothing
+
+object = Model {
+    mName = "object",
+    mTable = "objects",
+    mFields = [ "id" ::: IntegerColumn,
+                "dt" ::: CurrentDateColumn,
+                "name" ::: StringColumn,
+                "value" ::: StringColumn ],
+    mCached = [],
+    mChildren = [(sub,"id","pid")]
+    }
+
+sub = Model {
+    mName = "child",
+    mTable = "children",
+    mFields = [ "id" ::: IntegerColumn,
+                "pid" ::: IntegerColumn,
+                "body" ::: StringColumn ],
+    mCached = [],
+    mChildren = []
+    }
+
+table m = tableR (mTable m)
+
+insertM m = InsertQuery (mTable m) (map fieldName insfields) temps
+    where insfields = filter notid $ mFields m
+          notid s = not ("id" `isSuffixOf` (fieldName s))
+          temps = map (\f -> if (fieldType f)==CurrentDateColumn
+                               then "current_timestamp"
+                               else "?") insfields
+
+updateM m cond = UpdateQuery (mTable m) (map fieldName updfields) temps cond
+    where updfields = filter normal $ mFields m
+          normal s = (not ("id" `isSuffixOf` (fieldName s))) && ((fieldType s)/=CurrentDateColumn)
+          temps = replicate (length updfields) "?"
+
+countChildren m ord = count $ setFields fs $ ((table m) `joinT` childTable)
+        `restrict` ((childTable++"."++childId) :==: parentField)
+        `sgroup` (parentField++", "++ordField) `order` (Asceding ordField)
+    where childTable = mTable $ cModel $ head $ mChildren m
+          parentId = cParent $ head $ mChildren m
+          childId = cChild $ head $ mChildren m
+          parentField = mTable m ++"."++ parentId
+          ordField = (mTable m)++"."++ord
+          fs = parent++child
+          parent = map (\f -> if "id" `isSuffixOf` (fieldName f)
+                                then QField $ (mTable m)++"."++(fieldName f)
+                                else if (fieldName f)==ord
+                                       then QField $ (mTable m)++"."++(fieldName f)
+                                       else QFn "max" $ (mTable m)++"."++(fieldName f)) $ mFields m
+          child = [QField childTable]
+
+setFields fs q = q { qFields = fs }
+
+joinT q@(Query {qTables = tables}) tbl | TableList ts <- tables = q { qTables = TableJoin (ts++[tbl]) }
+                                       | TableJoin ts <- tables = q { qTables = TableJoin (ts++[tbl]) }

 select q fs = q {qFields= (map QField fs)}
 onlyFields = select
@@ -144,9 +196,9 @@ sgroup q grp = q {qGroup = (qGroup q)++[grp]}

 limit q pair = q {qLimits = Just pair}

-insertQ (Query fields tables _ _ _ _) values = InsertQuery (tablename $ head tables) (map fieldname fields) values
-
-updateQ (Query fields tables _ _ _ _) cond values = UpdateQuery (tablename $ head tables) (map fieldname fields) values cond
+-- insertQ (Query fields tables _ _ _ _) values = InsertQuery (tablename $ head tables) (map fieldname fields) values
+--
+-- updateQ (Query fields tables _ _ _ _) cond values = UpdateQuery (tablename $ head tables) (map fieldname fields) values cond

 -- myquery = (table "users") `select` ["name","passwd"] `order` (Asceding "name")

diff --git a/Framework/Storage.hs b/Framework/Storage.hs
index f2d8ed4..9173be1 100644
--- a/Framework/Storage.hs
+++ b/Framework/Storage.hs
@@ -10,6 +10,8 @@ module Framework.Storage


 import qualified Database.HDBC.Sqlite3 as Sqlite3
+import qualified Database.HDBC.MySQL as MySQL
+import qualified Database.HDBC.PostgreSQL as PostgreSQL
 import qualified Database.HDBC as D

 import Framework.Types
@@ -19,6 +21,7 @@ data DBConnection = forall c. D.IConnection c => DBC c

 connect :: String -> String -> IO DBConnection
 connect "sqlite3" file = DBC `fmap` (Sqlite3.connectSqlite3 file)
+connect "psql"  str    = DBC `fmap` (PostgreSQL.connectPostgreSQL str)

 connect' :: HttpActionParams -> IO DBConnection
 connect' (HP {dbDriver, dbPath}) = connect dbDriver dbPath
diff --git a/Framework/Urls.hs b/Framework/Urls.hs
index 83f827b..5ab2ed9 100644
--- a/Framework/Urls.hs
+++ b/Framework/Urls.hs
@@ -114,7 +114,7 @@ httpAddGetVar rq name value = urlencode (map packHeader pairs')
     where pairs' = update name value pairs
           pairs = decodePairs (uriQuery $ reqURI rq)

-decodePairs s = map (both tryDecode) (trace (show pairs) pairs)
+decodePairs s = map (both tryDecode) pairs
     where pairs = queryToArguments $ replaceplus s
           both f (x,y) = (f x, f y)
           tryDecode s | isUTF8Encoded s = decodeString s
ViewGit