#include "mpi.h"
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#define N 8


/*
 * Write msg to stdout from a single process only.
 *
 * msg  - NUL-terminated text, printed verbatim (no newline is appended).
 * comm - communicator whose rank 0 does the printing; every other rank
 *        returns without producing output.
 */
void Print(char *msg, MPI_Comm comm)
{
  int rank;

  MPI_Comm_rank(comm, &rank);

  /* Only rank 0 speaks, so the message appears once per communicator. */
  if (rank != 0)
    return;

  printf("%s", msg);
}


/*
 * Print an m x n array to stdout in transposed order.
 *
 * Each output line corresponds to one value of the second index and lists
 * the m entries mat[0][col] .. mat[m-1][col], so n lines of m numbers are
 * produced. Every number is formatted as "%7.4f" followed by one space.
 *
 * m   - extent of the first index
 * n   - extent of the second index
 * mat - the array to print
 */
void PrintLocalMat(int m, int n, float mat[][n])
{
  int col = 0;

  while (col < n)
    {
      /* One output line: walk the first index while the second stays fixed. */
      for (int row = 0; row < m; row++)
        printf("%7.4f ", mat[row][col]);

      printf("\n");
      col++;
    }
}

/*
 * Print an array that is distributed over all ranks of comm.
 *
 * Rank 0 prints its own m x n piece with PrintLocalMat, then receives the
 * piece of each remaining rank in increasing rank order (1, 2, ...) and
 * prints it the same way. Every other rank sends its piece to rank 0 using
 * tag 0.
 *
 * m, n - dimensions of the local piece; rank 0 receives m*n floats from each
 *        peer using its own m and n, so all ranks are expected to pass the
 *        same values.
 * mat  - the local piece held by the calling rank
 * comm - communicator over which the pieces are collected
 *
 * This is a collective-style routine: it must be called by every rank of
 * comm, otherwise rank 0 blocks in MPI_Recv.
 */
void PrintMat(int m, int n, float mat[][n], MPI_Comm comm)
{
  int numprocs, myid;
  MPI_Comm_size(comm, &numprocs);
  MPI_Comm_rank(comm, &myid);

  if (myid != 0)
    {
      /* The buffer holds floats, so it must be described as MPI_FLOAT. */
      MPI_Send(mat, m*n, MPI_FLOAT, 0, 0, comm);
      return;
    }

  /* Scratch space for one remote piece; only rank 0 ever needs it. */
  float tmp[m][n];

  PrintLocalMat(m, n, mat);

  for (int k = 1; k < numprocs; k++)
    {
      /* The status is never inspected, so let MPI skip filling it in. */
      MPI_Recv(tmp, m*n, MPI_FLOAT, k, 0, comm, MPI_STATUS_IGNORE);
      PrintLocalMat(m, n, tmp);
    }
}

/*
 * Transpose an n x n block in place.
 *
 * n   - side length of the square block
 * mat - the block; on return mat[r][c] holds the value previously stored in
 *       mat[c][r]
 */
void TransposeBlock(int n, float mat[n][n])
{
  /*
   * Visit every position strictly below the diagonal and exchange it with
   * its mirror image above the diagonal; diagonal entries stay where they are.
   */
  for (int r = 1; r < n; r++)
    {
      for (int c = 0; c < r; c++)
        {
          float held = mat[r][c];
          mat[r][c] = mat[c][r];
          mat[c][r] = held;
        }
    }
}

/*
 * Transpose, in place, each n x n block of an m x n array.
 *
 * The array is treated as a stack of square blocks: rows 0..n-1 form the
 * first block, rows n..2n-1 the second, and so on. Each block is handed to
 * TransposeBlock independently.
 *
 * m   - total number of rows; the loop advances n rows at a time, so m is
 *       expected to be a multiple of n (otherwise the last block would reach
 *       past row m-1)
 * n   - row length and block side length
 * mat - the array whose blocks are transposed
 */
void TransposeBlocks(int m, int n, float mat[][n])
{
  int first_row = 0;

  while (first_row < m)
    {
      /* mat + first_row points at the first row of the current block. */
      TransposeBlock(n, mat + first_row);
      first_row += n;
    }
}

/*
 * Distributed multiplication demo for N x N matrices.
 *
 * Each rank owns n = N / numprocs values of the second (distributed) index
 * and stores its pieces of A, B and C as local arrays of shape [N][n]; local
 * index j corresponds to global index J = myid*n + j.
 *
 * The matrices are filled with A[i][j] = (J+1)*(i+1) and B[i][j] = 1/A[i][j],
 * printed, multiplied block by block using MPI_Allgather, and the result C is
 * printed as well.
 *
 * Returns 0 on success, or EXIT_FAILURE when the number of processes does not
 * evenly divide N.
 */
int main(int argc, char *argv[])
{
    int numprocs, myid, n;

    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &numprocs);
    MPI_Comm_rank(MPI_COMM_WORLD, &myid);

    /*
     * The decomposition below only works when numprocs divides N:
     *  - with numprocs > N, n would be 0, giving zero-length arrays and a
     *    block loop (i += n) that never advances;
     *  - with any other non-divisor, numprocs*n < N, so some rows of the
     *    gathered buffer would be read without ever being written.
     * Every rank evaluates the same condition, so all of them leave together.
     */
    if (N % numprocs != 0)
      {
        if (myid == 0)
          fprintf(stderr,
                  "N = %d is not divisible by the number of processes (%d)\n",
                  N, numprocs);
        MPI_Finalize();
        return EXIT_FAILURE;
      }

    n = N / numprocs;

    float A[N][n];
    float B[N][n];
    float C[N][n];

    for (int i = 0; i < N; i++)
      for (int j = 0; j < n; j++)
        {
          int J = myid * n + j;          /* global counterpart of local j */
          int tmp = (J + 1) * (i + 1);
          A[i][j] = tmp;
          B[i][j] = 1.0 / tmp;
        }

    Print("A:\n", MPI_COMM_WORLD);
    PrintMat(N, n, A, MPI_COMM_WORLD);

    Print("B:\n", MPI_COMM_WORLD);
    PrintMat(N, n, B, MPI_COMM_WORLD);

    /* Receives one n x n block from every rank: numprocs*n*n = N*n floats. */
    float column[N][n];

    for (int i = 0; i < N; i += n)
      {
        /*
         * Every rank contributes rows i..i+n-1 of its local B. After the
         * gather, block p of column holds the rows sent by rank p.
         */
        MPI_Allgather(&B[i], n*n, MPI_FLOAT, column, n*n, MPI_FLOAT, MPI_COMM_WORLD);

        /*
         * Transposing each block makes column[k][j1] the B entry whose first
         * index is i+j1 and whose global second index is k.
         */
        TransposeBlocks(N, n, column);

        for (int j1 = 0; j1 < n; j1++)
          for (int j = 0; j < n; j++)
            {
              C[i+j1][j] = 0;
              for (int k = 0; k < N; k++)
                C[i+j1][j] += A[k][j] * column[k][j1];
            }
      }

    Print("C:\n", MPI_COMM_WORLD);
    PrintMat(N, n, C, MPI_COMM_WORLD);

    MPI_Finalize();
    return 0;
}
