// CC0 Public Domain: http://creativecommons.org/publicdomain/zero/1.0/
#ifndef __SIMPLE_SIGNAL_HH__
#define __SIMPLE_SIGNAL_HH__

#include <unistd.h>
#include <assert.h>
#include <stdint.h>
#include <functional>
#include <utility>
#include <vector>

namespace Simple {

namespace Lib {

/// ProtoSignal is the template implementation for callback lists.
/// The primary template is only declared; the partial specialisation for a function
/// signature `R (Args...)` together with a Collector type provides the implementation.
template<typename,typename> class ProtoSignal;   // undefined

/// CollectorInvocation invokes signal handlers differently depending on return type.
/// Only declared here; partial specialisations exist for non-void and void handler signatures.
/// NOTE(review): declared with `class` but specialised with `struct`; harmless, though some compilers warn (-Wmismatched-tags).
template<typename,typename> class CollectorInvocation;

/// CollectorLast remembers only the most recent handler result, which emit() then returns.
template<typename Result>
struct CollectorLast {
  typedef Result CollectorResult;
  explicit CollectorLast () :
    last_()
  {}
  /// Record @a r as the latest result; never stops the emission.
  inline bool
  operator() (Result r)
  {
    last_ = r;
    return true;
  }
  /// Result of the last handler invoked, or a value-initialized Result if none ran.
  CollectorResult
  result ()
  {
    return last_;
  }
private:
  Result last_;
};

/// CollectorDefault implements the default signal handler collection behaviour:
/// for handlers with a non-void return type it behaves exactly like CollectorLast.
template<typename Result>
struct CollectorDefault : public CollectorLast<Result> {
};

/// CollectorDefault specialisation for signals with void return type:
/// there is nothing to collect, so every invocation continues the emission and result() yields void.
template<>
struct CollectorDefault<void> {
  typedef void CollectorResult;
  inline bool operator() (void) { return true; }
  void        result     ()     {}
};

/// CollectorInvocation specialisation for regular signals: the handler's return value is
/// passed to the collector, whose boolean verdict decides whether the emission continues.
template<class Collector, class R, class... Args>
struct CollectorInvocation<Collector, R (Args...)> {
  inline bool
  invoke (Collector &collector, const std::function<R (Args...)> &cbf, Args... args)
  {
    const bool keep_emitting = collector (cbf (args...));
    return keep_emitting;
  }
};

/// CollectorInvocation specialisation for signals with void return type: the handler is
/// called for its side effects only, then the collector is asked (without a value) whether to continue.
template<class Collector, class... Args>
struct CollectorInvocation<Collector, void (Args...)> {
  inline bool
  invoke (Collector &collector, const std::function<void (Args...)> &cbf, Args... args)
  {
    cbf (args...);          // nothing to hand over to the collector
    return collector();
  }
};

/// ProtoSignal template specialised for the callback signature and collector.
template<class Collector, class R, class... Args>
class ProtoSignal<R (Args...), Collector> : private CollectorInvocation<Collector, R (Args...)> {
protected:
  typedef std::function<R (Args...)> CbFunction;
  typedef typename CbFunction::result_type Result;
  typedef typename Collector::CollectorResult CollectorResult;
private:
  /**
   * SignalLink implements a doubly-linked ring with ref-counted nodes containing the signal handlers.
   * References on a node are held by: ring membership (one per linked node, two for the ring head),
   * every emission currently parked on the node, and - once a node has been unlinked - the
   * reference it keeps on its former successor (see unlink()), so parked emissions can always advance.
   */
  struct SignalLink {
    SignalLink *next, *prev;
    CbFunction  function;
    int         ref_count;
    bool        unlinked;   // set by unlink(); afterwards only stale references keep the node alive
    explicit    SignalLink (const CbFunction &cbf) : next (nullptr), prev (nullptr), function (cbf), ref_count (1), unlinked (false) {}
    /*dtor*/   ~SignalLink ()           { assert (ref_count == 0); }
    /// Whether emissions should invoke this node: still part of the ring and holding a handler.
    bool        active     () const     { return !unlinked && function != nullptr; }
    void        incref     ()           { ref_count += 1; assert (ref_count > 0); }
    /// Drop one reference. A node reaching zero is deleted; if it had been unlinked, the reference
    /// it held on its successor is dropped too. Done iteratively so long chains of removed nodes
    /// cannot exhaust the stack.
    void
    decref ()
    {
      SignalLink *link = this;
      while (link)
        {
          link->ref_count -= 1;
          if (link->ref_count)
            {
              assert (link->ref_count > 0);
              return;
            }
          SignalLink *const successor = link->unlinked ? link->next : nullptr;
          delete link;
          link = successor;
        }
    }
    /// Remove this node from the ring and drop the ring's reference.
    /// ->next and ->prev are left intact for stale iterators (emissions parked on this node), and
    /// ->next is pinned so it cannot be freed before this node, even when it gets unlinked as well.
    /// The handler is only destroyed together with the node, so a handler removing itself while
    /// it runs is not destroyed underneath its own invocation.
    void
    unlink ()
    {
      assert (!unlinked);
      unlinked = true;
      if (next)
        next->prev = prev;
      if (prev)
        prev->next = next;
      if (next)
        next->incref();         // released when this node is deleted, see decref()
      decref();
    }
    /// Insert a node for @a cb right before this one (at the ring's end when called on the head)
    /// and return its address as connection ID.
    /// NOTE: IDs are node addresses, so the ID of a removed and freed handler may be reissued later.
    size_t
    add_before (const CbFunction &cb)
    {
      SignalLink *link = new SignalLink (cb);
      link->prev = prev; // link to last
      link->next = this;
      prev->next = link; // link from last
      prev = link;
      static_assert (sizeof (link) == sizeof (size_t), "sizeof size_t");
      return size_t (link);
    }
    /// Unlink the sibling whose address equals @a id; returns whether such a sibling was found.
    bool
    remove_sibling (size_t id)
    {
      for (SignalLink *link = this->next ? this->next : this; link != this; link = link->next)
        if (id == size_t (link))
          {
            link->unlink();     // deactivate and unlink sibling
            return true;
          }
      return false;
    }
  };
  SignalLink   *callback_ring_; // linked ring of callback nodes
  /*copy-ctor*/ ProtoSignal (const ProtoSignal&) = delete;
  ProtoSignal&  operator=   (const ProtoSignal&) = delete;
  /// Lazily allocate the ring head, so a signal without handlers only costs one pointer.
  void
  ensure_ring ()
  {
    if (!callback_ring_)
      {
        callback_ring_ = new SignalLink (CbFunction()); // ref_count = 1
        callback_ring_->incref(); // ref_count = 2, head of ring, never unlinked
        callback_ring_->next = callback_ring_; // ring head initialization
        callback_ring_->prev = callback_ring_; // ring tail initialization
      }
  }
public:
  /// ProtoSignal constructor, connects default callback if non-NULL.
  ProtoSignal (const CbFunction &method) :
    callback_ring_ (nullptr)
  {
    if (method != nullptr)
      {
        ensure_ring();
        callback_ring_->function = method;
      }
  }
  /// ProtoSignal destructor releases all resources associated with this signal.
  ~ProtoSignal ()
  {
    if (callback_ring_)
      {
        while (callback_ring_->next != callback_ring_)
          callback_ring_->next->unlink();
        assert (callback_ring_->ref_count >= 2);
        callback_ring_->decref();
        callback_ring_->decref();
      }
  }
  /// Operator to add a new function or lambda as signal handler, returns a handler connection ID.
  size_t operator+= (const CbFunction &cb)      { ensure_ring(); return callback_ring_->add_before (cb); }
  /// Operator to remove a signal handler through its connection ID, returns if a handler was removed.
  bool   operator-= (size_t connection)         { return callback_ring_ ? callback_ring_->remove_sibling (connection) : false; }
  /// Emit a signal, i.e. invoke all its callbacks and collect return types with the Collector.
  /// Handlers may add or remove handlers (including themselves) and emit recursively.
  CollectorResult
  emit (Args... args)
  {
    Collector collector;
    if (!callback_ring_)
      return collector.result();
    SignalLink *link = callback_ring_;
    link->incref();
    do
      {
        if (link->active())
          {
            const bool continue_emission = this->invoke (collector, link->function, args...);
            if (!continue_emission)
              break;
          }
        // Hold the successor before releasing the current node, which may free a chain of removed nodes.
        SignalLink *old = link;
        link = old->next;
        link->incref();
        old->decref();
      }
    while (link != callback_ring_);
    link->decref();
    return collector.result();
  }
};

} // Lib
// (back in namespace Simple)

/**
 * Signal is a template type providing an interface for arbitrary callback lists.
 * A signal type is declared with the function signature of its callbacks and, optionally,
 * a collector class type that processes the callback results.
 * Callbacks are added to a signal with operator+=, which returns a connection ID;
 * passing that ID to operator-= removes the callback again.
 * The callbacks of a signal are invoked with the emit() method, whose arguments follow the signature.
 * The result returned by emit() depends on the signal's collector class. By default, the result of
 * the last callback is returned from emit(). Collectors can be implemented to accumulate callback
 * results or to halt a running emission depending on callback results.
 * The signal implementation is safe against recursion, so callbacks may be removed and
 * added during a signal emission, and recursive emit() calls are also safe.
 * The overhead of an unused signal is intentionally kept very low, around the size of a single pointer.
 * Note that Signal types are non-copyable.
 */
template <typename SignalSignature,
          class Collector = Lib::CollectorDefault<typename std::function<SignalSignature>::result_type> >
struct Signal /*final*/ : Lib::ProtoSignal<SignalSignature, Collector>
{
  using ProtoSignal = Lib::ProtoSignal<SignalSignature, Collector>;
  using CbFunction  = typename ProtoSignal::CbFunction;
  /// Construct a signal, optionally connecting @a method as its default callback.
  Signal (const CbFunction &method = CbFunction()) :
    ProtoSignal (method)
  {}
};

/// This function creates a std::function by binding @a object to the member function pointer @a method.
/// The returned function refers to @a object by reference, so @a object must outlive every call to it.
/// The wrapper's return type is R exactly, so reference results are passed through instead of being
/// decayed into temporaries, and arguments are forwarded so rvalue-reference and move-only
/// parameters reach @a method unchanged.
template<class Instance, class Class, class R, class... Args> std::function<R (Args...)>
slot (Instance &object, R (Class::*method) (Args...))
{
  return [&object, method] (Args... args) -> R { return (object .* method) (std::forward<Args> (args)...); };
}

/// This function creates a std::function by binding @a object to the member function pointer @a method.
/// The returned function stores the pointer @a object, which must stay valid for every call to it.
/// The wrapper's return type is R exactly, so reference results are passed through instead of being
/// decayed into temporaries, and arguments are forwarded so rvalue-reference and move-only
/// parameters reach @a method unchanged.
template<class Class, class R, class... Args> std::function<R (Args...)>
slot (Class *object, R (Class::*method) (Args...))
{
  return [object, method] (Args... args) -> R { return (object ->* method) (std::forward<Args> (args)...); };
}

/// Keep signal emissions going while all handlers return !0 (true); the first zero result stops it.
template<typename Result>
struct CollectorUntil0 {
  typedef Result CollectorResult;
  explicit CollectorUntil0 () :
    result_()
  {}
  /// Result of the most recently invoked handler (value-initialized if none ran).
  const CollectorResult&
  result ()
  {
    return result_;
  }
  /// Store @a r and continue only while it is non-zero.
  inline bool
  operator() (Result r)
  {
    result_ = r;
    if (result_)
      return true;
    return false;
  }
private:
  CollectorResult result_;
};

/// Keep signal emissions going while all handlers return 0 (false); the first non-zero result stops it.
template<typename Result>
struct CollectorWhile0 {
  typedef Result CollectorResult;
  explicit CollectorWhile0 () :
    result_()
  {}
  /// Result of the most recently invoked handler (value-initialized if none ran).
  const CollectorResult&
  result ()
  {
    return result_;
  }
  /// Store @a r and continue only while it is zero.
  inline bool
  operator() (Result r)
  {
    result_ = r;
    if (result_)
      return false;
    return true;
  }
private:
  CollectorResult result_;
};

/// CollectorVector gathers the results of all signal handlers of an emission, in invocation order,
/// into a std::vector that emit() returns.
template<typename Result>
struct CollectorVector {
  typedef std::vector<Result> CollectorResult;
  /// All results collected so far.
  const CollectorResult&
  result ()
  {
    return result_;
  }
  /// Append @a r; never stops the emission.
  inline bool
  operator() (Result r)
  {
    result_.push_back (r);
    return true;
  }
private:
  CollectorResult result_;
};

} // Simple

#endif // __SIMPLE_SIGNAL_HH__



#ifndef DISABLE_TESTS

#include <string>
#include <stdarg.h>

/// Format like printf(3) into a std::string; yields an empty string if formatting fails.
static std::string string_printf (const char *format, ...) __attribute__ ((__format__ (__printf__, 1, 2)));
static std::string
string_printf (const char *format, ...)
{
  std::string result;
  char *str = NULL;
  va_list args;
  va_start (args, format);
  const int length = vasprintf (&str, format, args);
  va_end (args);
  // On failure the contents of str are undefined, so it must neither be read nor freed.
  if (length >= 0)
    {
      result.assign (str, length);  // explicit length keeps embedded NUL characters (e.g. from %c)
      free (str);
    }
  return result;
}

/// Monotonic timestamp in nanoseconds; terminates the process if the clock cannot be read.
static uint64_t
timestamp_benchmark ()
{
  struct timespec now = { 0, 0 };
  if (__builtin_expect (clock_gettime (CLOCK_MONOTONIC, &now) < 0, 0))
    {
      perror ("failed in clock_gettime");
      exit (-1);
    }
  const uint64_t nanoseconds_per_second = 1000000000ULL;
  return now.tv_sec * nanoseconds_per_second + now.tv_nsec;
}

/// Static accessors for the counter that the benchmarks increment.
struct TestCounter {
  static uint64_t get  ();
  static void     set  (uint64_t value);
  static void     add2 (void *unused, uint64_t delta);
};

/// Benchmark entry point for TestCounter::add2. Deliberately NOT in the anonymous namespace:
/// with external linkage another translation unit could reassign it, so the optimizer cannot
/// treat it as a constant and inline the call inside the benchmark loops.
void        (*test_counter_add2) (void*, uint64_t) = TestCounter::add2;

namespace { // Anon
static uint64_t test_counter_var = 0;   // storage behind TestCounter::get/set/add2
} // Anon

class BasicSignalTests {
  static std::string accu;  // transcript of all handler invocations, compared with the expected text
  struct Foo {
    /// Member-function handler, connected through Simple::slot().
    char
    foo_bool (float f, int i, std::string s)
    {
      const float sum = f + i + s.size();
      accu += string_printf ("Foo: %.2f\n", sum);
      return true;
    }
  };
  /// Plain function handler.
  static char
  float_callback (float f, int, std::string)
  {
    accu += string_printf ("float: %.2f\n", f);
    return 0;
  }
public:
  /// Connect, emit and disconnect handlers on two signals, then compare the transcript.
  static void
  run()
  {
    accu.clear();
    // three handlers of different kinds on a char-returning signal
    Simple::Signal<char (float, int, std::string)> sig1;
    size_t id1 = sig1 += float_callback;
    size_t id2 = sig1 += [] (float, int i, std::string) { accu += string_printf ("int: %d\n", i); return 0; };
    size_t id3 = sig1 += [] (float, int, const std::string &s) { accu += string_printf ("string: %s\n", s.c_str()); return 0; };
    sig1.emit (.3, 4, "huhu");
    // each connection ID removes its handler exactly once
    bool success;
    success = sig1 -= id1;
    assert (success == true);
    success = sig1 -= id1;
    assert (success == false);
    success = sig1 -= id2;
    assert (success == true);
    success = sig1 -= id3;
    assert (success == true);
    success = sig1 -= id3;
    assert (success == false);
    success = sig1 -= id2;
    assert (success == false);
    // member-function handlers, bound by reference and by pointer
    Foo foo;
    sig1 += Simple::slot (foo, &Foo::foo_bool);
    sig1 += Simple::slot (&foo, &Foo::foo_bool);
    sig1.emit (.5, 1, "12");
    // a void signal
    Simple::Signal<void (std::string, int)> sig2;
    sig2 += [] (std::string msg, int) { accu += string_printf ("msg: %s", msg.c_str()); };
    sig2 += [] (std::string, int d)   { accu += string_printf (" *%d*\n", d); };
    sig2.emit ("in sig2", 17);

    accu += "DONE";

    static const char expected[] =
      "float: 0.30\n"
      "int: 4\n"
      "string: huhu\n"
      "Foo: 3.50\n"
      "Foo: 3.50\n"
      "msg: in sig2 *17*\n"
      "DONE";
    assert (accu == expected);
  }
};
std::string BasicSignalTests::accu;


class TestCollectorVector {
  static int handler1   ()  { return 1; }
  static int handler42  ()  { return 42; }
  static int handler777 ()  { return 777; }
  public:
  /// Every handler result must appear in the collected vector, in connection order.
  static void
  run ()
  {
    int (*const handlers[]) () = { handler777, handler42, handler1, handler42, handler777 };
    Simple::Signal<int (), Simple::CollectorVector<int>> sig_vector;
    for (int (*handler) () : handlers)
      sig_vector += handler;
    const std::vector<int> reference = { 777, 42, 1, 42, 777, };
    std::vector<int> results = sig_vector.emit();
    assert (results == reference);
  }
};

class TestCollectorUntil0 {
  bool check1, check2;
  TestCollectorUntil0() : check1 (false), check2 (false) {}
  bool handler_true  ()  { check1 = true; return true; }
  bool handler_false ()  { check2 = true; return false; }
  bool handler_abort ()  { abort(); }   // only reached if the emission fails to stop at handler_false
  public:
  /// The emission must stop at the first handler returning false.
  static void
  run ()
  {
    typedef bool (TestCollectorUntil0::*Handler) ();
    const Handler handlers[] = { &TestCollectorUntil0::handler_true,
                                 &TestCollectorUntil0::handler_false,
                                 &TestCollectorUntil0::handler_abort };
    TestCollectorUntil0 self;
    Simple::Signal<bool (), Simple::CollectorUntil0<bool>> sig_until0;
    for (Handler handler : handlers)
      sig_until0 += Simple::slot (self, handler);
    assert (!self.check1 && !self.check2);
    const bool result = sig_until0.emit();
    assert (!result && self.check1 && self.check2);
  }
};

class TestCollectorWhile0 {
  bool check1, check2;
  TestCollectorWhile0() : check1 (false), check2 (false) {}
  bool handler_0     ()  { check1 = true; return false; }
  bool handler_1     ()  { check2 = true; return true; }
  bool handler_abort ()  { abort(); }   // only reached if the emission fails to stop at handler_1
  public:
  /// The emission must stop at the first handler returning true.
  static void
  run ()
  {
    typedef bool (TestCollectorWhile0::*Handler) ();
    const Handler handlers[] = { &TestCollectorWhile0::handler_0,
                                 &TestCollectorWhile0::handler_1,
                                 &TestCollectorWhile0::handler_abort };
    TestCollectorWhile0 self;
    Simple::Signal<bool (), Simple::CollectorWhile0<bool>> sig_while0;
    for (Handler handler : handlers)
      sig_while0 += Simple::slot (self, handler);
    assert (!self.check1 && !self.check2);
    const bool result = sig_while0.emit();
    assert (result == true && self.check1 && self.check2);
  }
};

/// Time 999999 emissions of a signal with a single handler and print the per-emission cost.
static void
bench_simple_signal()
{
  Simple::Signal<void (void*, uint64_t)> sig_increment;
  sig_increment += test_counter_add2;
  const uint64_t start_counter = TestCounter::get();
  const uint64_t benchstart = timestamp_benchmark();
  uint64_t i = 0;
  while (i < 999999)
    {
      sig_increment.emit (NULL, 1);
      ++i;
    }
  const uint64_t benchdone = timestamp_benchmark();
  const uint64_t end_counter = TestCounter::get();
  assert (end_counter - start_counter == i);
  const double ns_per_emission = size_t (benchdone - benchstart) * 1.0 / size_t (i);
  printf ("OK\n  Benchmark: Simple::Signal: %fns per emission (size=%zu): ", ns_per_emission,
          sizeof (sig_increment));
}

/// Baseline: time 999999 direct calls through a function pointer and print the per-call cost.
static void
bench_callback_loop()
{
  void (*counter_increment) (void*, uint64_t) = test_counter_add2;
  const uint64_t start_counter = TestCounter::get();
  const uint64_t benchstart = timestamp_benchmark();
  uint64_t i = 0;
  while (i < 999999)
    {
      counter_increment (NULL, 1);
      ++i;
    }
  const uint64_t benchdone = timestamp_benchmark();
  const uint64_t end_counter = TestCounter::get();
  assert (end_counter - start_counter == i);
  const double ns_per_round = size_t (benchdone - benchstart) * 1.0 / size_t (i);
  printf ("OK\n  Benchmark: callback loop: %fns per round: ", ns_per_round);
}

/// Current value of the benchmark counter.
uint64_t
TestCounter::get ()
{
  const uint64_t current = test_counter_var;
  return current;
}

/// Overwrite the benchmark counter with @a value.
void
TestCounter::set (uint64_t value)
{
  test_counter_var = value;
}

/// Increase the benchmark counter by @a delta; the pointer argument is ignored.
void
TestCounter::add2 (void*, uint64_t delta)
{
  test_counter_var = test_counter_var + delta;
}

/// Run every test and benchmark in order, printing each label followed by "OK".
int
main (int   argc,
      char *argv[])
{
  struct TestCase {
    const char *label;
    void      (*run) ();
  };
  const TestCase cases[] = {
    { "Signal/Basic Tests: ",                 BasicSignalTests::run },
    { "Signal/CollectorVector: ",             TestCollectorVector::run },
    { "Signal/CollectorUntil0: ",             TestCollectorUntil0::run },
    { "Signal/CollectorWhile0: ",             TestCollectorWhile0::run },
    { "Signal/Benchmark: Simple::Signal: ",   bench_simple_signal },
    { "Signal/Benchmark: callback loop: ",    bench_callback_loop },
  };
  for (const TestCase &test_case : cases)
    {
      printf ("%s", test_case.label);
      test_case.run();
      printf ("OK\n");
    }
  return 0;
}

#endif // DISABLE_TESTS

// g++ -Wall -O2 -std=gnu++0x -pthread simplesignal.cc -lrt && ./a.out
