/* **********************************************************
 * Copyright 1998 VMware, Inc.  All rights reserved. -- VMware Confidential
 * **********************************************************/

/*
 * memtrack.c --
 *
 *      Utility module for tracking memory allocated and/or 
 *      locked by the monitor. 
 *     
 *      
 */

#ifdef linux
/* Must come before any kernel header file --hpreg */
#   include "driver-config.h"

#   include <linux/string.h> /* memset() in the kernel */
#elif WINNT_DDK
#   undef PAGE_SIZE        /* Redefined in ntddk.h, and we use that defn. */
#   undef PAGE_SHIFT
#   include <ntddk.h>
#else 
#   include <string.h>
#endif

#include "vmware.h"
#include "x86.h"
#include "memtrack.h"
#include "vmx86.h"
#include "hostif.h"
#include "vm_atomic.h"


/*
 * MemTrack - We track memory using a two-level, page-table-like
 * structure. The first level is allocated at page granularity, and
 * as many MemTrackEntries as possible are packed into each page.
 * We currently handle 4 GB worth of pages.  -- edward
 */

/*
 * Number of slots in the first-level directory (MemTrack.mapPages):
 * 1 << (32 - PAGE_SHIFT) entries, divided by the number of
 * MemTrackEntry structures that fit in one page. CEILING is defined
 * outside this file; presumably a round-up division -- TODO confirm.
 */
#define MAX_TRACE_PAGE_DIR \
   CEILING(1 << (32 - PAGE_SHIFT), PAGE_SIZE / sizeof (MemTrackEntry))

/*
 * Hash table geometry. Each table has HASH_TABLE_SIZE buckets, every
 * bucket being one pointer, and the buckets are spread over
 * HASH_TABLE_PAGES separately allocated pages.
 *
 * NOTE(review): HASH_TABLE_PAGES is a truncating division, so
 * HASH_TABLE_SIZE must be a multiple of HASH_TABLE_ENTRIES_PER_PAGE
 * for every bucket to have a backing page -- confirm for the
 * PAGE_SIZE / pointer-size combinations this is built with.
 */
#define HASH_TABLE_SIZE    16384
#define HASH_TABLE_ENTRIES_PER_PAGE (PAGE_SIZE / sizeof (void *))
#define HASH_TABLE_PAGES   (HASH_TABLE_SIZE / HASH_TABLE_ENTRIES_PER_PAGE)

/*
 * MemTrack -- state of one memory tracker. Callers only ever see it as
 * the opaque handle returned by MemTrack_Init.
 */
typedef struct MemTrack {
   int numEntriesPerPage;    /* PAGE_SIZE / sizeof (MemTrackEntry), set by MemTrack_Init */
   int numPages;             /* Number of entries stored so far (one per successful MemTrack_Add) */
   /* Directory of entry pages; a slot stays NULL until MemTrack_Add needs it. */
   char *mapPages[MAX_TRACE_PAGE_DIR];
   MemTrackEntry **vpnHashTablePages[HASH_TABLE_PAGES]; /* Pages of bucket heads: VPN to entry hashtable */
#ifdef MEMTRACK_MPN_LOOKUP
   MemTrackEntry **mpnHashTablePages[HASH_TABLE_PAGES]; /* Pages of bucket heads: MPN to entry hashtable */
#endif
} MemTrack;

/*
 * HASH_VPN --
 *
 *      Locate the hash bucket for a page number. The bucket index is
 *      the page number (truncated to unsigned) modulo HASH_TABLE_SIZE;
 *      it is then split into the page that holds the bucket and the
 *      slot within that page.
 *
 * Results:
 *      Address of the bucket head pointer.
 */
static INLINE MemTrackEntry **
HASH_VPN(MemTrackEntry **hashTablePages[HASH_TABLE_PAGES], VPN vpn) 
{
   unsigned bucket = (unsigned)vpn % HASH_TABLE_SIZE;
   unsigned pageIndex = bucket / HASH_TABLE_ENTRIES_PER_PAGE;
   unsigned slot = bucket % HASH_TABLE_ENTRIES_PER_PAGE;
   MemTrackEntry **bucketPage = hashTablePages[pageIndex];

   return &bucketPage[slot];
}

#ifdef MEMTRACK_MPN_LOOKUP
/*
 * MPN buckets are located with exactly the same hash as VPN buckets;
 * the MPN is simply passed through HASH_VPN's VPN parameter.
 */
#define HASH_MPN(_ht, _mpn) HASH_VPN((_ht), (_mpn))
#endif

/*
 *----------------------------------------------------------------------
 *
 * MemTrack_Init --
 *
 *      Allocate and initialize the memory tracker: the tracker
 *      structure itself plus the zero-filled pages that back the VPN
 *      hash table (and the MPN hash table when MEMTRACK_MPN_LOOKUP
 *      is defined).
 *
 * Results:
 *      Handle used to access the memtracker, or NULL if any
 *      allocation failed.
 *
 * Side effects:
 *      Memory allocation. On failure, everything allocated up to that
 *      point is released again before returning.
 *
 *----------------------------------------------------------------------
 */

void *
MemTrack_Init(void)
{
   MemTrack *st;
   int i;

   st = HostIF_AllocKernelMem(sizeof *st, FALSE);
   if (st == NULL) { 
      Warning("MemTrack_Init failed\n");
      return NULL;
   }

   /*
    * Zeroing the whole structure leaves numPages at 0 and every
    * directory / hash-table slot NULL. The error path below relies on
    * this to tell allocated pages from slots that were never filled.
    */
   memset(st, 0, sizeof *st);

   for (i = 0; i < HASH_TABLE_PAGES; i++) {
      VA vpnPage = (VA)HostIF_AllocPage();

      if (vpnPage == 0) { 
         Warning("MemTrack_Init failed on hashTablePages %d\n",i);
         goto error;
      }
      st->vpnHashTablePages[i] = (MemTrackEntry **)vpnPage;
      memset((char*)vpnPage, 0, PAGE_SIZE);
   }

#ifdef MEMTRACK_MPN_LOOKUP
   for (i = 0; i < HASH_TABLE_PAGES; i++) {
      VA mpnPage = (VA)HostIF_AllocPage();

      if (mpnPage == 0) { 
         Warning("MemTrack_Init failed on hashTablePages %d\n",i);
         goto error;
      }
      st->mpnHashTablePages[i] = (MemTrackEntry **)mpnPage;
      memset((char*)mpnPage, 0, PAGE_SIZE);
   }
#endif 

   ASSERT(sizeof (MemTrackEntry) <= PAGE_SIZE);
   st->numEntriesPerPage = PAGE_SIZE / sizeof (MemTrackEntry);

   return (void *)st;

error:
   /*
    * Undo the partial initialization: free whichever hash-table pages
    * were obtained, then the tracker structure itself.
    */
   for (i = 0; i < HASH_TABLE_PAGES; i++) {
      if (st->vpnHashTablePages[i] != NULL) {
         HostIF_FreePage(st->vpnHashTablePages[i]);
         st->vpnHashTablePages[i] = NULL;
      }
#ifdef MEMTRACK_MPN_LOOKUP
      if (st->mpnHashTablePages[i] != NULL) {
         HostIF_FreePage(st->mpnHashTablePages[i]);
         st->mpnHashTablePages[i] = NULL;
      }
#endif
   }
   HostIF_FreeKernelMem(st);
   return NULL;
}


/*
 *----------------------------------------------------------------------
 *
 * MemTrack_Add --
 *
 *      Record a (vpn, mpn) pair in the memory tracker. The new entry
 *      takes the next free slot of the two-level entry store and is
 *      pushed onto the front of its VPN hash chain (and of its MPN
 *      hash chain when MEMTRACK_MPN_LOOKUP is defined).
 *
 * Results:
 *      A pointer to the new entry, or NULL if the directory is full
 *      or a directory page could not be allocated.
 *
 * Side effects:
 *      May allocate one directory page.
 *
 *----------------------------------------------------------------------
 */

MemTrackEntry *
MemTrack_Add(void *s,  // memtracker handler
             VPN vpn,  // VPN of entry
             MPN mpn)  // MPN of entry
{
   MemTrack *tracker = (MemTrack *) s;
   int slot = tracker->numPages;
   int dirIndex = slot / tracker->numEntriesPerPage;
   MemTrackEntry *entry;
   MemTrackEntry **bucket;

   /* The directory has a fixed number of slots. */
   if (dirIndex >= MAX_TRACE_PAGE_DIR) {
      return NULL;
   }

   /* Directory pages are created on first use. */
   if (tracker->mapPages[dirIndex] == NULL) {
      tracker->mapPages[dirIndex] = HostIF_AllocPage();
      if (tracker->mapPages[dirIndex] == NULL) {
         return NULL;
      }
   }

   /* Fill in the next free entry of that directory page. */
   entry = &((MemTrackEntry *) tracker->mapPages[dirIndex])
               [slot % tracker->numEntriesPerPage];
   entry->vpn = vpn;
   entry->mpn = mpn;

   /* Push the entry onto the head of its VPN hash chain. */
   bucket = HASH_VPN(tracker->vpnHashTablePages, vpn);
   entry->vpnHashChain = *bucket;
   *bucket = entry;

#ifdef MEMTRACK_MPN_LOOKUP
   /* Likewise for the MPN hash chain. */
   bucket = HASH_MPN(tracker->mpnHashTablePages, mpn);
   entry->mpnHashChain = *bucket;
   *bucket = entry;
#endif

   tracker->numPages++;

   return entry;
}

/*
 *----------------------------------------------------------------------
 *
 * MemTrack_LookupVPN --
 *
 *      Find the entry recorded for a given VPN by walking the hash
 *      chain of the bucket that VPN maps to.
 *
 * Results:
 *      A pointer to the first entry on the chain whose vpn matches,
 *      or NULL if there is none.
 *
 *----------------------------------------------------------------------
 */

MemTrackEntry *
MemTrack_LookupVPN(void *s,  // memtracker handler
		   VPN vpn)  // Value to lookup
{
   MemTrack *tracker = (MemTrack *) s;
   MemTrackEntry *entry;

   for (entry = *HASH_VPN(tracker->vpnHashTablePages, vpn);
        entry != NULL;
        entry = entry->vpnHashChain) {
      if (entry->vpn == vpn) {
         return entry;
      }
   }

   return NULL;
}


#ifdef MEMTRACK_MPN_LOOKUP
/*
 *----------------------------------------------------------------------
 *
 * MemTrack_LookupMPN --
 *
 *      Find the entry recorded for a given MPN by walking the hash
 *      chain of the bucket that MPN maps to.
 *
 * Results:
 *      A pointer to the first entry on the chain whose mpn matches,
 *      or NULL if there is none.
 *
 *----------------------------------------------------------------------
 */
MemTrackEntry *
MemTrack_LookupMPN(void *s, MPN mpn)
{
   MemTrack *tracker = (MemTrack *) s;
   MemTrackEntry *entry;

   for (entry = *HASH_MPN(tracker->mpnHashTablePages, mpn);
        entry != NULL;
        entry = entry->mpnHashChain) {
      if (entry->mpn == mpn) {
         return entry;
      }
   }

   return NULL;
}
#endif


/*
 *----------------------------------------------------------------------
 *
 * MemTrack_Scan --
 *
 *      Visit the stored entries in the order they were added, handing
 *      each one to the supplied search function, and stop at the
 *      first call that returns a non-NULL value.
 *
 * Results:
 *      The first non-NULL value returned by searchFunc, or NULL if
 *      every call returned NULL (or the tracker holds no entries).
 *
 * Side effects:
 *      None beyond those of searchFunc.
 *
 *----------------------------------------------------------------------
 */

void *
MemTrack_Scan(void *s,  // Handle for memtracker
        void *arg,  // Argument to searchFunc
        void *(*searchFunc)(void *arg, MemTrackEntry *ptr))
{
   MemTrack *tracker = (MemTrack *) s;
   int idx = 0;

   while (idx < tracker->numPages) {
      /* Split the flat index into directory page and slot in that page. */
      MemTrackEntry *entries =
         (MemTrackEntry *) tracker->mapPages[idx / tracker->numEntriesPerPage];
      void *found = searchFunc(arg, &entries[idx % tracker->numEntriesPerPage]);

      if (found != NULL) {
         return found;
      }
      idx++;
   }

   return NULL;
}




/*
 *----------------------------------------------------------------------
 *
 * MemTrack_Cleanup --
 *
 *      Tear down the memtracker. The user-provided CleanUp function
 *      is first called once for every stored entry; afterwards the
 *      directory pages, the hash-table pages and the tracker
 *      structure itself are freed. The handle must not be used again.
 *
 * Results:
 *      Number of entries CleanUp was called on.
 *
 * Side effects:
 *      Memory free
 *
 *----------------------------------------------------------------------
 */

int 
MemTrack_Cleanup(void *s,   // Mem tracker handle
                 void (*CleanUp)(void *clientData,MemTrackEntry*),
                 void *clientData)   // free functions.
{
   MemTrack *tracker = (MemTrack *) s;
   int visited = 0;
   int idx;
   int dirIndex;
   int tablePage;

   /* Give the client a chance to release whatever each entry refers to. */
   for (idx = 0; idx < tracker->numPages; idx++) {
      MemTrackEntry *entries =
         (MemTrackEntry *) tracker->mapPages[idx / tracker->numEntriesPerPage];

      CleanUp(clientData, &entries[idx % tracker->numEntriesPerPage]);
      visited++;
   }

   /* Release the directory pages that were actually allocated. */
   for (dirIndex = 0; dirIndex < MAX_TRACE_PAGE_DIR; dirIndex++) {
      if (tracker->mapPages[dirIndex]) {
         HostIF_FreePage(tracker->mapPages[dirIndex]);
         tracker->mapPages[dirIndex] = NULL;
      }
   }

   /* Release the pages backing the hash tables. */
   for (tablePage = 0; tablePage < HASH_TABLE_PAGES; tablePage++) {
      if (tracker->vpnHashTablePages[tablePage]) {
         HostIF_FreePage(tracker->vpnHashTablePages[tablePage]);
         tracker->vpnHashTablePages[tablePage] = NULL;
      }
#ifdef MEMTRACK_MPN_LOOKUP
      if (tracker->mpnHashTablePages[tablePage]) {
         HostIF_FreePage(tracker->mpnHashTablePages[tablePage]);
         tracker->mpnHashTablePages[tablePage] = NULL;
      }
#endif
   }

   /* Finally the tracker structure itself. */
   HostIF_FreeKernelMem(tracker);
   return visited;
}

