(* 
    OCaml XenStore Daemon.
    Copyright (C) 2008 Patrick Colp University of British Columbia

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*)

(* Connect to the XenStore ring of domain [id].

   Binds an interdomain event channel from domain [id]'s xenbus port
   (as reported by [Os.get_xenbus_port]), maps the xenbus interface for
   the resulting local port and wraps it in a connection.  It then kicks
   the other end with a notification on that port.  Returns a fresh
   [Domain.domain] for [id] using that connection. *)
let domxs_init id =
  let xenbus_port = Os.get_xenbus_port () in
  let local_port = Eventchan.bind_interdomain id xenbus_port in
  let conn = new Connection.connection (Os.map_xenbus local_port) in
  Eventchan.notify local_port;
  new Domain.domain id conn

(* Apply a transaction's per-domain entry delta to the live counters.

   [domain_entry.Transaction.entries] is a signed delta for domain
   [domain_entry.Transaction.id].  A positive delta calls
   [domains#entry_incr] that many times, and a negative delta calls
   [domains#entry_decr] that many times.  A zero delta does nothing. *)
let domain_entry_change domains domain_entry =
  let owner = domain_entry.Transaction.id in
  let delta = domain_entry.Transaction.entries in
  if delta > 0 then
    for _step = 1 to delta do
      domains#entry_incr owner
    done
  else if delta < 0 then
    for _step = delta to (-1) do
      domains#entry_decr owner
    done

(* Top-level XenStore daemon state.

   The object owns the following:
   - the set of connected domains;
   - the node permission table;
   - the open transactions;
   - the node store and the registered watches;
   - the event-channel port bound to the domain-exception VIRQ, which is
     set by [initialise_domains].

   [options] supplies the quota limits and domain-initialisation flags
   read below.  [store] is the node store, which is also handed to the
   transaction table.

   On construction, the root node's permissions are set to a single entry
   built from [Permission.NONE] and domain 0.  [initialise_store] is then
   run. *)
class xenstored options store =
object(self)
  val m_domains = new Domain.domains
  val m_options : Option.t = options
  val m_permissions = new Permission.permissions
  val m_transactions = new Transaction.transactions store
  val m_store = store
  val m_watches = new Watch.watches
  (* Assigned by [initialise_domains].  It stays [Constants.null_file_descr]
     when domain initialisation is disabled. *)
  val mutable m_virq_port = Constants.null_file_descr
  initializer
    m_permissions#set
      [ Permission.string_of_permission (Permission.make Permission.NONE 0) ]
      store Store.root_path;
    self#initialise_store

  method private store = m_store

  (* Register [domain] and open its "connection" trace. *)
  method add_domain domain =
    self#domains#add domain;
    Trace.create domain#id "connection"

  (* Register [watch] on behalf of [domain].
     Unprivileged domains may hold at most
     [quota_num_watches_per_domain] watches; privileged domains are not
     limited.
     Raises [Constants.Xs_error] with [E2BIG] when the quota is reached. *)
  method add_watch (domain : Domain.domain) watch =
    if not (Domain.is_unprivileged domain) || self#watches#num_watches_for_domain domain#id < self#options.Option.quota_num_watches_per_domain
    then self#watches#add watch
    else raise (Constants.Xs_error (Constants.E2BIG, "Xenstored.xenstored#add_watch", "Too many watches"))

  (* Commit [transaction].  Returns [true] on success and [false] if any
     step raised.

     The per-domain entry deltas are read before committing.  They are
     applied to the live quota counters only after
     [transactions#commit] has returned.  A commit that fails therefore
     leaves the counters untouched.  Previously the counters were bumped
     first and stayed modified even when the commit raised.

     The catch-all handler is deliberate: any failure is reported to the
     caller as [false].  NOTE(review): the exact exceptions [commit] can
     raise are not visible here.  If [fire_watches] raises after a
     successful commit, [false] is still returned. *)
  method commit transaction =
    try
      let domain_entries = self#transactions#domain_entries transaction in
      let changes = self#transactions#commit transaction in
      List.iter (domain_entry_change self#domains) domain_entries;
      Transaction.fire_watches self#watches changes;
      true
    with _ -> false

  (* Number of store entries charged to [domain_id], as seen from inside
     [transaction].  This is the domain's committed count (0 when
     [domains#entry_count] raises [Not_found]) plus the pending delta
     recorded for [domain_id] in [transaction], if there is one.

     The lookup is keyed on the [domain_id] argument, the domain being
     queried.  It is not keyed on the domain that opened the transaction. *)
  method domain_entry_count transaction (domain_id : int) =
    let committed =
      try self#domains#entry_count domain_id with Not_found -> 0 in
    try
      let pending =
        (List.find
           (fun entry -> entry.Transaction.id = domain_id)
           (self#transactions#domain_entries transaction)).Transaction.entries in
      pending + committed
    with Not_found -> committed

  (* Release one entry from the quota of the domain owning [path].  The
     owner is the domain of the first permission entry for [path]; this
     raises [Failure] if [permissions#get] returns an empty list.  Only
     unprivileged owners are counted.

     With a non-zero transaction id, the decrement is recorded as a
     pending delta on the transaction.  Otherwise it is applied to the
     domain's live counter. *)
  method domain_entry_decr store transaction path =
    let domain_id = (List.hd (self#permissions#get store path)).Permission.domain_id in
    if Domain.is_unprivileged_id domain_id then
      if transaction.Transaction.transaction_id <> 0l
      then self#transactions#domain_entry_decr transaction domain_id
      else self#domains#entry_decr domain_id

  (* Charge one entry to the quota of the domain owning [path].  The owner
     is the domain of the first permission entry for [path].  Only
     unprivileged owners are counted.

     With a non-zero transaction id, the increment is recorded on the
     transaction.  The check then compares the transaction's pending delta
     plus the domain's committed count against the quota.  Otherwise the
     live counter is incremented and checked directly.

     When the count exceeds [quota_num_entries_per_domain], the increment
     is undone.  The method then raises [Constants.Xs_error] with [EINVAL]
     and [path] as the message. *)
  method domain_entry_incr store transaction path =
    let domain_id = (List.hd (self#permissions#get store path)).Permission.domain_id in
    if Domain.is_unprivileged_id domain_id then
      if transaction.Transaction.transaction_id <> 0l
      then (
        self#transactions#domain_entry_incr transaction domain_id;
        (* Presumably an entry for [domain_id] exists now that it has just
           been incremented, so this [List.find] should not raise. *)
        let entry_count = (List.find (fun entry -> entry.Transaction.id = domain_id) (self#transactions#domain_entries transaction)).Transaction.entries in
        let entry_count_current = try self#domains#entry_count domain_id with Not_found -> 0 in
        if entry_count + entry_count_current > self#options.Option.quota_num_entries_per_domain
        then (
          (* Roll back the pending increment before reporting the error. *)
          self#transactions#domain_entry_decr transaction domain_id;
          raise (Constants.Xs_error (Constants.EINVAL, "Xenstored.xenstored#domain_entry_incr", path))
        )
      )
      else (
        self#domains#entry_incr domain_id;
        let entry_count = self#domains#entry_count domain_id in
        if entry_count > self#options.Option.quota_num_entries_per_domain
        then (
          (* Roll back the live increment before reporting the error. *)
          self#domains#entry_decr domain_id;
          raise (Constants.Xs_error (Constants.EINVAL, "Xenstored.xenstored#domain_entry_incr", path))
        )
      )

  method domains = m_domains

  (* Bring up the initial domain connections when [options.domain_init]
     is set.

     In that case the method:
     - checks the hypervisor handle;
     - initialises event channels;
     - creates the dom0 connection;
     - binds the domain-exception VIRQ (storing the port in [m_virq_port]);
     - registers dom0.

     With [separate_domain], the XenStore domain ([Os.get_domxs_id]) is
     registered first.  dom0 is then set up through [Domain.domu_init].
     Otherwise dom0 is connected with [domxs_init 0].

     Returns [Eventchan.get_channel ()], or [Constants.null_file_descr]
     when domain initialisation is disabled.
     NOTE(review): assumes [Utils.barf_perror] aborts the daemon; confirm. *)
  method initialise_domains =
    if self#options.Option.domain_init
    then (
      if Domain.xc_handle = Constants.null_file_descr then Utils.barf_perror "Failed to open connection to hypervisor\n";
      Eventchan.init ();
      let dom0 =
        if self#options.Option.separate_domain
        then (
          self#add_domain (domxs_init (Os.get_domxs_id ()));
          Domain.domu_init 0 (Os.get_dom0_port ()) (Os.get_dom0_mfn ()) true
        )
        else domxs_init 0 in
      m_virq_port <- Eventchan.bind_virq Constants.virq_dom_exc;
      if m_virq_port = Constants.null_file_descr then Utils.barf_perror "Failed to bind to domain exception virq port\n";
      self#add_domain dom0;
      Eventchan.get_channel ()
    )
    else Constants.null_file_descr

  (* Create the "tool<sep>xenstored" node under [Store.root_path] and
     register permissions for it with domain id 0. *)
  method initialise_store =
    let path = Store.root_path ^ "tool" ^ Store.dividor_str ^ "xenstored" in
    self#store#create_node path;
    self#permissions#add self#store path 0

  (* Open a new transaction for [domain].
     Unprivileged domains may hold at most [quota_max_transaction] open
     transactions.
     Raises [Constants.Xs_error] with [ENOSPC] when that quota is reached. *)
  method new_transaction domain store =
    if not (Domain.is_unprivileged domain) || self#transactions#num_transactions_for_domain domain#id < self#options.Option.quota_max_transaction
    then self#transactions#new_transaction domain store
    else raise (Constants.Xs_error (Constants.ENOSPC, "Xenstored.xenstored#new_transaction", "Too many transactions"))

  method options = m_options
  method permissions = m_permissions

  (* Drop [domain].  The method removes it from the domain set, closes its
     "connection" trace, and discards its watches and transactions. *)
  method remove_domain domain =
    self#domains#remove domain;
    Trace.destroy domain#id "connection";
    self#watches#remove_watches domain;
    self#transactions#remove_domain domain

  method transactions = m_transactions
  method virq_port = m_virq_port
  method watches = m_watches
end
