/*************************************************** */
/* Rule Set Based Access Control                     */
/*                                                   */
/* Author and (c) 1999-2019: Amon Ott <ao@rsbac.org> */
/*                                                   */
/* Last modified: 11/Dec/2019                        */
/*************************************************** */

#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <time.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <grp.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <rsbac/types.h>
#include <rsbac/syscalls.h>
#include <rsbac/error.h>
#include "nls.h"
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

  /* Program name as invoked (argv[0]); used as prefix in messages. */
  char * progname;
  /* Plaintext password from -p; handed to the add-group call and wiped
   * with memset() before the program exits. */
  char password[RSBAC_MAXNAMELEN] = "";
  /* Hash algorithm name from -A; when empty, the add-group call without
   * an explicit algorithm is used. */
  char hash_algo[RSBAC_UM_ALGO_NAME_LEN] = "";
  /* Binary password hash decoded from the -Q hex string, or NULL when -Q
   * was not given; process() stores it through UM_cryptpass. */
  char * crypt_pass = NULL;
  /* Time-to-live set by -t / -D / -T, or read from a copied group (-C);
   * passed to the add-group and add-group-member calls. */
  rsbac_time_t ttl = 0;
  /* Incremented once per -v. */
  int verbose = 0;
  /* Set by -o and -O: take name, gid and members from the system group
   * database (getgrnam() / getgrent()). */
  int useold = 0;
  /* Set by -G: process() starts at gid 100 when no gid was given and
   * skips gids that already exist. */
  int sysgroup = 0;
  /* Set by -O: add every group returned by getgrent(). */
  int addallold = 0;
  /* Set when -g supplied a gid; -r only takes effect together with it. */
  int maygidreplace = 0;
  /* Set by -r. */
  int gidreplace = 0;
  /* Set by -R. */
  int alwaysreplace = 0;
  /* Transaction number: RSBAC_TA environment value, overridden by -N. */
  rsbac_list_ta_number_t ta_number = 0;
  /* Virtual user set from -S; RSBAC_UM_VIRTUAL_KEEP when none was given. */
  rsbac_um_set_t vset = RSBAC_UM_VIRTUAL_KEEP;

/*
 * Print the usage summary to stdout.
 *
 * Lists every flag accepted by main(), including -v and -Q.
 */
void use(void)
    {
      printf(gettext("%s (RSBAC %s)\n***\n"), progname, VERSION);
      printf(gettext("Use: %s [flags] groupname\n"), progname);
      printf(gettext(" -h = this help, -- = no more flags,\n"));
      printf(gettext(" -v = verbose,\n"));
      printf(gettext(" -p password = password in plaintext,\n"));
      printf(gettext(" -Q password = encrypted password as hex string,\n"));
      printf(gettext(" -A hash-algo = hash algorithm to use, e.g. sha256 (default: use kernel default),\n"));
      printf(gettext(" -g gid = gid to use,\n"));
      printf(gettext(" -G = create system group (gid >= 100),\n"));
      printf(gettext(" -r = replace same groupname with different gid (requires -g),\n"));
      printf(gettext(" -R = replace same groupname in any case,\n"));
      printf(gettext(" -t = set relative time-to-live in secs (role/type comp, admin, assign only)\n"));
      printf(gettext(" -T = set absolute time-to-live in secs (role/type comp, admin, assign only)\n"));
      printf(gettext(" -D = set relative time-to-live in days (role/type comp, admin, assign only)\n"));
      printf(gettext(" -o = use values from old group entry,\n"));
      printf(gettext(" -O = add all existing groups (implies -o)\n"));
      printf(gettext(" -C group = copy existing group\n"));
      printf(gettext(" -S n = virtual user set n\n"));
      printf(gettext(" -N ta = transaction number (default = value of RSBAC_TA, if set, or 0)\n"));
    }

/*
 * Decode a single hexadecimal digit.
 *
 * Returns the digit's value (0..15), or -1 if c is not a hex digit.
 */
static int hex_digit_value(char c)
  {
    if(c >= '0' && c <= '9')
      return c - '0';
    if(c >= 'a' && c <= 'f')
      return c - 'a' + 10;
    if(c >= 'A' && c <= 'F')
      return c - 'A' + 10;
    return -1;
  }

/*
 * Decode a hex-encoded password hash into binary form.
 *
 * to   - destination buffer; RSBAC_UM_PASS_LEN bytes are written
 * from - NUL-terminated string of exactly RSBAC_UM_PASS_LEN * 2 hex digits
 *
 * Returns 0 on success, or -RSBAC_EINVALIDVALUE if the length is wrong or
 * any character is not a hexadecimal digit. Every character is checked
 * explicitly: going through strtoul() would turn non-hex input into zero
 * bytes and would accept whitespace or sign prefixes.
 */
int password_read(char * to, char * from)
  {
    unsigned int i;

    if(strlen(from) != RSBAC_UM_PASS_LEN * 2)
      {
        fprintf(stderr, "Wrong encrypted password length!\n");
        return -RSBAC_EINVALIDVALUE;
      }
    for(i = 0; i < RSBAC_UM_PASS_LEN; i++)
      {
        int hi = hex_digit_value(from[2 * i]);
        int lo = hex_digit_value(from[2 * i + 1]);

        if(hi < 0 || lo < 0)
          {
            fprintf(stderr, "Invalid character in encrypted password!\n");
            return -RSBAC_EINVALIDVALUE;
          }
        to[i] = (char) ((hi << 4) | lo);
      }
    return 0;
  }

/*
 * Add one group to RSBAC user management.
 *
 * name   - group name; copied into the entry, truncated to
 *          RSBAC_UM_NAME_LEN - 1 characters
 * group  - gid to use, including the virtual set part
 * entry  - template entry, passed by value so the caller's copy is kept
 * gr_mem - NULL-terminated list of member user names to add, or NULL
 *
 * Behaviour depends on the global flags (useold, sysgroup, gidreplace,
 * maygidreplace, alwaysreplace, verbose) and on password, hash_algo, ttl,
 * crypt_pass, vset and ta_number.
 *
 * Returns the error from removing a group that was to be replaced, or
 * otherwise the result of the last RSBAC call made. A failed add is
 * reported but does not stop the member and password steps.
 */
int process(char * name, rsbac_gid_t group,
            struct rsbac_um_group_entry_t entry,
            char ** gr_mem)
  {
      int res;

      if(useold)
        {
          if(verbose) {
            if(RSBAC_GID_SET(group) != RSBAC_UM_VIRTUAL_KEEP)
              printf("Adding old group %s with gid %u/%u\n",
                     name, RSBAC_GID_SET(group), RSBAC_GID_NUM(group));
            else
              printf("Adding old group %s with gid %u\n",
                     name, RSBAC_GID_NUM(group));
          }
        }
      else if(sysgroup) {
          /* System group: start at gid 100 unless a gid was given, then
           * advance to the first gid that does not exist yet. */
          if(RSBAC_GID_NUM(group) == RSBAC_NO_GROUP)
            group = RSBAC_GEN_GID(RSBAC_GID_SET(group), 100);
          while (rsbac_um_group_exists(ta_number, group))
            group++;
      }
      if(verbose) {
          if(RSBAC_GID_SET(group) != RSBAC_UM_VIRTUAL_KEEP)
            printf("Adding group %u/%u:%s", RSBAC_GID_SET(group), RSBAC_GID_NUM(group), name);
          else
            printf("Adding group %u:%s", RSBAC_GID_NUM(group), name);
          if(alwaysreplace)
            printf(" (replace existing)");
          else
            if(gidreplace && maygidreplace)
              printf(" (replace different gid)");
          printf("\n");
      }
      strncpy(entry.name, name, RSBAC_UM_NAME_LEN);
      entry.name[RSBAC_UM_NAME_LEN - 1] = 0;
      if((gidreplace && maygidreplace) || alwaysreplace)
        {
          /* Look up an existing group of the same name. It is removed
           * first if -R was given, or if its gid differs from the one
           * requested (comparing only the number when no set is given). */
          rsbac_gid_t tmp_group = RSBAC_GEN_GID(vset, RSBAC_NO_GROUP);
          if(   !rsbac_um_get_gid(ta_number, name, &tmp_group)
             && (   alwaysreplace
                 || (   (RSBAC_GID_SET(group) == RSBAC_UM_VIRTUAL_KEEP)
                     && (RSBAC_GID_NUM(tmp_group) != RSBAC_GID_NUM(group))
                    )
                 || (   (RSBAC_GID_SET(group) != RSBAC_UM_VIRTUAL_KEEP)
                     && (tmp_group != group)
                    )
                )
            )
            {
              if(verbose)
                {
                  if (RSBAC_GID_SET(group) == RSBAC_UM_VIRTUAL_KEEP)
                    printf("First removing group %u:%s, then adding %u:%s\n",
                           RSBAC_GID_NUM(tmp_group), entry.name,
                           RSBAC_GID_NUM(group), entry.name);
                  else
                    /* Old group is shown with its own set, new group
                     * with the set it is actually added to. */
                    printf("First removing group %u/%u:%s, then adding %u/%u:%s\n",
                           RSBAC_GID_SET(tmp_group), RSBAC_GID_NUM(tmp_group), entry.name,
                           RSBAC_GID_SET(group), RSBAC_GID_NUM(group), entry.name);
                }
              res = rsbac_um_remove_group(ta_number, tmp_group);
              if(res)
                {
                  if (vset != RSBAC_UM_VIRTUAL_KEEP)
                    fprintf(stderr, "%u/%s: ", RSBAC_GID_SET(group), name);
                  else
                    fprintf(stderr, "%s: ", name);
                  show_error(res);
                  return res;
                }
            }
        }
      /* Without -A the add call that takes no algorithm name is used. */
      if (hash_algo[0] == 0)
        res = rsbac_um_add_group(ta_number, group, &entry, password, ttl);
      else
        res = rsbac_um_add_group_hash(ta_number, group, &entry, password, hash_algo, ttl);
      if(res)
        {
          fprintf(stderr, "%s: ", name);
          show_error(res);
        }
      if(gr_mem)
        {
          rsbac_uid_t tmp_uid;

          /* Members are added one by one; a failure for one member is
           * reported and the remaining members are still tried. */
          while(*gr_mem)
            {
              if(verbose)
                printf("Adding group %s member %s\n", name, *gr_mem);
              tmp_uid = RSBAC_GEN_UID(vset, RSBAC_NO_USER);
              res = rsbac_um_get_uid(ta_number, *gr_mem, &tmp_uid);
              if(res)
                {
                  if(vset != RSBAC_UM_VIRTUAL_KEEP)
                    fprintf(stderr, "Lookup group %u/%s member %s: ", vset, name, *gr_mem);
                  else
                    fprintf(stderr, "Lookup group %s member %s: ", name, *gr_mem);
                  show_error(res);
                }
              else
                {
                  res = rsbac_um_add_gm(ta_number, tmp_uid,
                                        RSBAC_GID_NUM(group), ttl);
                  if(res)
                    {
                      if(vset != RSBAC_UM_VIRTUAL_KEEP)
                        fprintf(stderr, "Adding group %u/%s member %s (uid %u): ",
                                vset, name, *gr_mem, RSBAC_UID_NUM(tmp_uid));
                      else
                        fprintf(stderr, "Adding group %s member %s (uid %u): ", name, *gr_mem, RSBAC_UID_NUM(tmp_uid));
                      show_error(res);
                    }
                }
              gr_mem++;
            }
        }
      if(crypt_pass)
        {
          union rsbac_um_mod_data_t data;

          /* Store the pre-hashed password given with -Q.
           * NOTE(review): this uses the gid passed in; if no gid was
           * given and one is assigned during the add, this call may not
           * address the new group - confirm. */
          memcpy(data.string, crypt_pass, RSBAC_UM_PASS_LEN);
          res = rsbac_um_mod_group(ta_number, group, UM_cryptpass, &data);
          show_error(res);
        }
     return res;
  }

/*
 * Fill a group entry template from an existing group (used by -C).
 *
 * group   - gid of the group to copy from
 * entry_p - entry that receives the existing group's name
 *
 * Also sets the global ttl when the group's ttl can be read; a failure to
 * read the ttl is ignored. Returns 0 on success, or the error from
 * reading the group name.
 */
int fill_entry(rsbac_gid_t group, struct rsbac_um_group_entry_t * entry_p)
{
  int res;
  union rsbac_um_mod_data_t data;

  res = rsbac_um_get_group_item(ta_number, group, UM_name, &data);
  if(res)
    return res;
  /* Bounded copy with explicit termination, the same way process() fills
   * the name field, so an over-long string cannot overrun the entry. */
  strncpy(entry_p->name, data.string, RSBAC_UM_NAME_LEN);
  entry_p->name[RSBAC_UM_NAME_LEN - 1] = 0;

  res = rsbac_um_get_group_item(ta_number, group, UM_ttl, &data);
  if(!res)
    ttl = data.ttl;
  return 0;
}

int main(int argc, char ** argv)
{
  int res = 0;
  struct rsbac_um_group_entry_t entry = DEFAULT_UM_G_ENTRY;
  rsbac_gid_t group = RSBAC_GEN_GID(vset, RSBAC_NO_GROUP);
  u_int stopflags = FALSE;

  locale_init();

  progname = argv[0];
  {
    char * env = getenv("RSBAC_TA");

    if(env)
      ta_number = strtoul(env,0,0);
  }
  while((argc > 1) && (argv[1][0] == '-') && !stopflags)
    {
      char * pos = argv[1];
      pos++;
      while(*pos)
        {
          switch(*pos)
            {
              case '-':
                stopflags = TRUE;
                break;
              case 'h':
                use();
                return 0;
              case 'v':
                verbose++;
                break;
              case 'o':
                useold = 1;
                break;
              case 'G':
                sysgroup = 1;
                break;
              case 'O':
                addallold = 1;
                useold = 1;
                break;
              case 'C':
                if(argc > 2)
                  {
                    rsbac_gid_t egroup = RSBAC_GEN_GID(vset, RSBAC_NO_GROUP);

                    if(rsbac_um_get_gid(ta_number, argv[2], &egroup))
                      {
                        char * tmp_name = argv[2];
                        char * p = tmp_name;
                        rsbac_um_set_t tmp_vset = vset;

                        while (*p && (*p != '/'))
                          p++;
                        if (*p) {
                          *p = 0;
                          tmp_vset = strtoul(tmp_name, NULL, 0);
                          *p = '/';
                          p++;
                          tmp_name = p;
                        }
                        egroup = strtoul(tmp_name, NULL, 0);
                        if(!egroup && strcmp(tmp_name,"0"))
                          {
                            fprintf(stderr, gettext("%s: Unknown group %s\n"), progname, argv[2]);
                            return 1;
                          }
                        egroup = RSBAC_GEN_GID(tmp_vset, egroup);
                      }
                    if (fill_entry (egroup, &entry)) {
                      fprintf(stderr, gettext("%s: Reading group %s (%u/%u) failed, exiting!\n"),
                              progname, argv[2], RSBAC_GID_SET(egroup), RSBAC_GID_NUM(egroup));
                      return 1;
                    }
                    group = egroup;
                    argc--;
                    argv++;
                  }
                else
                  fprintf(stderr, gettext("%s: missing argument for parameter %c\n"), progname, *pos);
                break;
              case 'p':
                if(argc > 2)
                  {
                    strncpy(password, argv[2], RSBAC_MAXNAMELEN);
                    password[RSBAC_MAXNAMELEN - 1] = 0;
                    argc--;
                    argv++;
                  }
                else
                  fprintf(stderr, gettext("%s: missing argument for parameter %c\n"), progname, *pos);
                break;
              case 'A':
                if(argc > 2)
                  {
                   strncpy(hash_algo, argv[2], RSBAC_UM_ALGO_NAME_LEN);
                    hash_algo[RSBAC_UM_ALGO_NAME_LEN - 1] = 0;
                    argc--;
                    argv++;
                  }
                else
                  fprintf(stderr, gettext("%s: missing argument for parameter %c\n"), progname, *pos);
                break;
              case 'Q':
                if(argc > 2)
                  {
                    crypt_pass = malloc(RSBAC_MAXNAMELEN);
                    if(!crypt_pass)
                      error_exit(-ENOMEM);
                    res = password_read(crypt_pass, argv[2]);
                    error_exit(res);
                    argc--;
                    argv++;
                  }
                else
                  fprintf(stderr, gettext("%s: missing argument for parameter %c\n"), progname, *pos);
                break;
              case 'g':
                if(argc > 2)
                  {
                    group = RSBAC_GEN_GID(vset, strtoul(argv[2],0,0));
                    maygidreplace = 1;
                    argc--;
                    argv++;
                  }
                else
                  fprintf(stderr, gettext("%s: missing argument for parameter %c\n"), progname, *pos);
                break;
              case 'r':
                gidreplace = 1;
                break;
              case 'R':
                alwaysreplace = 1;
                break;
              case 't':
                if(argc > 2)
                  {
                    res = strtokmgu32(argv[2], &ttl);
                    error_exit(res);
                    argc--;
                    argv++;
                  }
                else
                  fprintf(stderr, gettext("%s: missing ttl value for parameter %c\n"), progname, *pos);
                break;
              case 'D':
                if(argc > 2)
                  {
                    ttl = 86400 * strtoul(argv[2], 0, 10);
                    argc--;
                    argv++;
                  }
                else
                  fprintf(stderr, gettext("%s: missing ttl value for parameter %c\n"), progname, *pos);
                break;
              case 'T':
                if(argc > 2)
                  {
                    rsbac_time_t now = time(NULL);
                    ttl = strtoul(argv[2], 0, 10);
                    if(ttl > now)
                      {
                        ttl -= now;
                        argc--;
                        argv++;
                      }
                    else
                      {
                        fprintf(stderr,
                                gettext("%s: ttl value for parameter %c is in the past, exiting\n"), progname, *pos);
                        exit(1);
                      }
                  }
                else
                  fprintf(stderr, gettext("%s: missing ttl value for parameter %c\n"), progname, *pos);
                break;
              case 'N':
                if(argc > 2)
                  {
                    ta_number = strtoul(argv[2], 0, 10);
                    argc--;
                    argv++;
                  }
                else
                  {
                    fprintf(stderr, gettext("%s: missing transaction number value for parameter %c\n"), progname, *pos);
                    exit(1);
                  }
                break;
              case 'S':
                if(argc > 2)
                  {
                    if (rsbac_get_vset_num(argv[2], &vset))
                      {
                        fprintf(stderr, gettext("%s: invalid virtual set number for parameter %c\n"), progname, *pos);
                        exit(1);
                      }
                    group = RSBAC_GEN_GID(vset, group);
                    argc--;
                    argv++;
                  }
                else
                  {
                    fprintf(stderr, gettext("%s: missing virtual set number for parameter %c\n"), progname, *pos);
                    exit(1);
                  }
                break;

              default:
                fprintf(stderr, gettext("%s: unknown parameter %c\n"), progname, *pos);
                exit(1);
            }
          pos++;
        }
      argv++;
      argc--;
    }

  if(addallold)
    {
      struct group * group_info_p;

      setgrent();
      while((group_info_p = getgrent()))
        process(group_info_p->gr_name, group_info_p->gr_gid,
                entry, group_info_p->gr_mem);
      endgrent();
      memset(password, 0, RSBAC_MAXNAMELEN);
      exit(0);
    }
  else
  if (argc > 1)
    {
      int i;
      struct group * group_info_p;

      for(i=1; i< argc; i++)
        {
          if(useold)
            {
              group_info_p = getgrnam(argv[i]);
              if(!group_info_p)
                fprintf(stderr, "%s: old entry not found!\n", argv[i]);
              else
                process(group_info_p->gr_name, group_info_p->gr_gid,
                        entry, group_info_p->gr_mem);
            }
          else {
            char * tmp_name = argv[i];
            char * p = tmp_name;
            rsbac_um_set_t tmp_vset = vset;

            while (*p && (*p != '/'))
              p++;
            if (*p) {
              *p = 0;
              if (rsbac_get_vset_num(tmp_name, &tmp_vset))
                {
                  fprintf(stderr, gettext("%s: invalid virtual set number %s, skipping\n"), progname, tmp_name);
                  continue;
                }
              *p = '/';
              p++;
              tmp_name = p;
            }
            process(tmp_name, RSBAC_GEN_GID(tmp_vset, group), entry, NULL);
          }
        }
      memset(password, 0, RSBAC_MAXNAMELEN);
      exit(0);
    }
  else
    {
      use();
      return 1;
    }
  return (res);
}
