My Weblog

Blog about programming and math

Recursive Matrix Multiplication in Haskell (First Step to Strassen’s Algorithm)

Strassen’s algorithm multiplies two 2 x 2 matrices using seven multiplications rather than eight. As a first step, I wrote a quick recursive block solution that still uses eight multiplications. It is not very efficient, and I think the main problem is Haskell’s lazy lists, which are a poor representation for dense matrices. Better candidates are the repa or dph (Data Parallel Haskell) libraries, and I will try to implement the algorithm using them. This implementation assumes square input matrices whose dimension is a power of 2. Please let me know if you find any bugs or have suggestions to improve this code.
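The step that Strassen’s algorithm is built on can be sketched as follows: a minimal example for a single 2 x 2 product, using the standard seven products M1 to M7 and comparing them with the ordinary eight-multiplication formula. It only uses (+), (-) and (*) and keeps every product in A-times-B order, so it should also work when the four entries are matrix blocks (for example with the Matrix Num instance in the listing). The names Quad, naive2 and strassen2 are hypothetical and not part of the post’s code.

```haskell
-- Hypothetical representation: a 2 x 2 matrix (or 2 x 2 grid of blocks)
-- written as its four entries, row by row: (x11, x12, x21, x22).
type Quad a = (a, a, a, a)

-- Ordinary 2 x 2 product: eight multiplications.
naive2 :: Num a => Quad a -> Quad a -> Quad a
naive2 (a11, a12, a21, a22) (b11, b12, b21, b22) =
  ( a11 * b11 + a12 * b21, a11 * b12 + a12 * b22
  , a21 * b11 + a22 * b21, a21 * b12 + a22 * b22 )

-- Strassen's 2 x 2 step: seven multiplications plus extra additions/subtractions.
-- Every product keeps an A-side factor on the left and a B-side factor on the
-- right, so the formulas stay valid for non-commuting entries such as blocks.
strassen2 :: Num a => Quad a -> Quad a -> Quad a
strassen2 (a11, a12, a21, a22) (b11, b12, b21, b22) =
  (m1 + m4 - m5 + m7, m3 + m5, m2 + m4, m1 - m2 + m3 + m6)
  where
    m1 = (a11 + a22) * (b11 + b22)
    m2 = (a21 + a22) * b11
    m3 = a11 * (b12 - b22)
    m4 = a22 * (b21 - b11)
    m5 = (a11 + a12) * b22
    m6 = (a21 - a11) * (b11 + b12)
    m7 = (a12 - a22) * (b21 + b22)

main :: IO ()
main = do
  let a = (1, 2, 3, 4) :: Quad Integer -- [[1,2],[3,4]]
      b = (5, 6, 7, 8) :: Quad Integer -- [[5,6],[7,8]]
  print (naive2 a b)
  print (strassen2 a b)
```

Running main prints (19,22,43,50) twice, i.e. the product [[19,22],[43,50]]. Applied recursively to n/2 x n/2 blocks, saving one of the eight multiplications at every level is what lowers the cost from O(n^3) to O(n^log2 7), roughly O(n^2.81); because of the extra additions, switching to the plain product below some block size is usually worthwhile.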

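import Data.List ( transpose )


-- A matrix stored as a list of rows.
data Matrix a = Matrix { get :: [[a]] } deriving ( Show )

-- Element-wise (+) and (-), and the usual row-by-column product for (*).
-- fromInteger cannot be given a sensible meaning without knowing the size.
instance ( Num a ) => Num ( Matrix a ) where
    ( Matrix xs ) + ( Matrix ys ) = Matrix ( zipWith ( zipWith ( + ) ) xs ys )
    ( Matrix xs ) - ( Matrix ys ) = Matrix ( zipWith ( zipWith ( - ) ) xs ys )
    ( Matrix xs ) * ( Matrix ys ) = Matrix ( map ( \x -> map ( sum . zipWith ( * ) x ) ( transpose ys ) ) xs )
    abs ( Matrix xs ) = Matrix ( map ( map abs ) xs )
    signum ( Matrix xs ) = Matrix ( map ( map signum ) xs )
    fromInteger _ = error "fromInteger: matrix size is unknown"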


-- Multiply two n x n matrices (n a power of 2) by splitting each into four
-- n/2 x n/2 blocks and summing eight recursive block products.
-- lev is the current recursion depth: after two levels of splitting, or once the
-- blocks are 1 x 1, the ordinary product from the Num instance is used instead.
recurMult :: ( Num a ) => Int -> Int -> Matrix a -> Matrix a -> Matrix a
recurMult n lev xs ys
     | n <= 1 || lev >= 2 = xs * ys -- stop at 1 x 1 blocks or after 2 levels of splitting
     | otherwise =  Matrix ret where
              n' = div n 2
              ( a , b ) = ( get xs , get ys )
              -- upper and lower halves of the rows
              ( a_u , a_l ) = splitAt n' a
              ( b_u , b_l ) = splitAt n' b
              -- left and right halves of each row give the four quadrants
              ( a11 , a12 ) = ( Matrix ( map ( take n' ) a_u ) , Matrix ( map ( drop n' ) a_u ) )
              ( a21 , a22 ) = ( Matrix ( map ( take n' ) a_l ) , Matrix ( map ( drop n' ) a_l ) )
              ( b11 , b12 ) = ( Matrix ( map ( take n' ) b_u ) , Matrix ( map ( drop n' ) b_u ) )
              ( b21 , b22 ) = ( Matrix ( map ( take n' ) b_l ) , Matrix ( map ( drop n' ) b_l ) )
              -- C11 = A11 B11 + A12 B21 and so on: eight block multiplications
              sub = recurMult n' ( lev + 1 )
              Matrix c11 = sub a11 b11 + sub a12 b21
              Matrix c12 = sub a11 b12 + sub a12 b22
              Matrix c21 = sub a21 b11 + sub a22 b21
              Matrix c22 = sub a21 b12 + sub a22 b22
              -- glue the result blocks back together: [ c11 c12 ; c21 c22 ]
              ret = zipWith ( ++ ) c11 c12 ++ zipWith ( ++ ) c21 c22


-- Convenience wrapper working on plain lists of rows.
tempMult :: ( Num a ) => [ [ a ] ] -> [ [ a ] ] -> [ [ a ] ]
tempMult xs ys = get $ recurMult ( length xs ) 0 ( Matrix xs ) ( Matrix ys )

*Main> tempMult ( [  [ 1 , 2 , 3 , 4 ] , [ 1 , 2 , 3 , 4 ] , [ 1 , 2 , 3 , 4 ] , [ 1 , 2 , 3 , 4 ] ] )  ( [  [ 1 , 2 , 3 , 4 ] , [ 1 , 2 , 3 , 4 ] , [ 1 , 2 , 3 , 4 ] , [ 1 , 2 , 3 , 4 ] ] )
[[10,20,30,40],[10,20,30,40],[10,20,30,40],[10,20,30,40]]
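The code above assumes the dimension is a power of 2. One common way to lift that restriction is to pad the inputs with zeros up to the next power of 2, multiply, and crop the result. The following is a minimal sketch under that assumption; nextPow2, cols, padTo, cropTo, multPadded and naiveMult are hypothetical helpers, not part of the post’s code, and naiveMult only stands in for tempMult so the block runs on its own.

```haskell
import Data.List ( transpose )

-- Smallest power of 2 that is >= k (returns 1 for k <= 1).
nextPow2 :: Int -> Int
nextPow2 k = until ( >= k ) ( * 2 ) 1

-- Number of columns of a list-of-rows matrix (0 for an empty matrix).
cols :: [[a]] -> Int
cols ( row : _ ) = length row
cols []          = 0

-- Pad with zeros on the right and at the bottom so the matrix becomes m x m.
padTo :: Num a => Int -> [[a]] -> [[a]]
padTo m rows = map padRow rows ++ replicate ( m - length rows ) ( replicate m 0 )
  where padRow r = r ++ replicate ( m - length r ) 0

-- Keep only the top-left r x c corner.
cropTo :: Int -> Int -> [[a]] -> [[a]]
cropTo r c = map ( take c ) . take r

-- Multiply an r x k matrix by a k x c matrix of any size: pad both to a common
-- power-of-2 square, multiply with the supplied square multiplier, then crop.
-- The zero padding leaves the top-left r x c block of the product unchanged.
multPadded :: Num a => ( [[a]] -> [[a]] -> [[a]] ) -> [[a]] -> [[a]] -> [[a]]
multPadded mult xs ys = cropTo ( length xs ) ( cols ys ) ( mult ( padTo m xs ) ( padTo m ys ) )
  where m = nextPow2 ( maximum [ length xs, cols xs, length ys, cols ys ] )

-- Plain list-based product, standing in for tempMult in this self-contained sketch.
naiveMult :: Num a => [[a]] -> [[a]] -> [[a]]
naiveMult xs ys = [ [ sum ( zipWith ( * ) row col ) | col <- transpose ys ] | row <- xs ]

main :: IO ()
main = do
  let a = [ [ 1, 2, 3 ], [ 4, 5, 6 ], [ 7, 8, 9 ] ] :: [[Integer]]
  print ( multPadded naiveMult a a )
```

This prints [[30,36,42],[66,81,96],[102,126,150]]. With the post’s code loaded, passing tempMult instead of naiveMult should give the same result, since the 3 x 3 inputs are padded to 4 x 4. The price is wasted work when the size is just above a power of 2: a 5 x 5 input, for instance, becomes 8 x 8.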


December 1, 2012 - Posted by | Programming
