/*************************************************** */
/* Rule Set Based Access Control                     */
/*                                                   */
/* Author and (c) 1999-2019: Amon Ott <ao@rsbac.org> */
/*                                                   */
/* Last modified: 11/Dec/2019                        */
/*************************************************** */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <time.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <rsbac/types.h>
#include <rsbac/syscalls.h>
#include <rsbac/error.h>
#include "nls.h"
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

/* Name this program was invoked with (argv[0]); used in usage and error messages. */
char * progname;
/* Binary password data decoded from the hex string given with -Q. */
char password[RSBAC_MAXNAMELEN] = "";
/* Hash algorithm name given with -A; an empty string means none was given. */
char hash_algo[RSBAC_UM_ALGO_NAME_LEN] = "";
/* Virtual user set given with -S; stays RSBAC_UM_VIRTUAL_KEEP when -S is absent. */
rsbac_um_set_t vset = RSBAC_UM_VIRTUAL_KEEP;

/* Print the program banner and the list of supported flags to stdout. */
void use(void)
    {
      /* Only the first two lines carry conversions; the rest is fixed text. */
      printf(gettext("%s (RSBAC %s)\n***\n"), progname, VERSION);
      printf(gettext("Use: %s [flags] groupname\n"), progname);

      fputs(gettext(" -h = this help, -- = no more flags,\n"), stdout);
      fputs(gettext(" -p password = password in plaintext,\n"), stdout);
      fputs(gettext(" -P = disable password,\n"), stdout);
      fputs(gettext(" -Q password = encrypted password (from backup),\n"), stdout);
      fputs(gettext(" -A hash-algo = hash algorithm to use, e.g. sha256 (default: use kernel default),\n"), stdout);
      fputs(gettext(" -g name = change groupname,\n"), stdout);
      fputs(gettext(" -t = set relative time-to-live in secs (role/type comp, admin, assign only)\n"), stdout);
      fputs(gettext(" -T = set absolute time-to-live in secs (role/type comp, admin, assign only)\n"), stdout);
      fputs(gettext(" -D = set relative time-to-live in days (role/type comp, admin, assign only)\n"), stdout);
      fputs(gettext(" -S n = virtual user set n\n"), stdout);
      fputs(gettext(" -N ta = transaction number (default = value of RSBAC_TA, if set, or 0)\n"), stdout);
    }

/* Return the value (0..15) of hexadecimal digit c, or -1 if c is not one. */
static int password_hex_digit(char c)
  {
    if(c >= '0' && c <= '9')
      return c - '0';
    if(c >= 'a' && c <= 'f')
      return c - 'a' + 10;
    if(c >= 'A' && c <= 'F')
      return c - 'A' + 10;
    return -1;
  }

/*
 * Decode a hex-encoded password string into raw bytes.
 *
 * to:   destination buffer; receives strlen(from) / 2 bytes and is not
 *       0-terminated by this function.
 * from: 0-terminated string of hex digits; its length must be exactly
 *       RSBAC_UM_PASS_LEN * 2 or RSBAC_UM_PWDATA_LEN * 2.
 *
 * Returns the number of decoded bytes, or -RSBAC_EINVALIDVALUE if the
 * length is wrong or any character is not a hexadecimal digit. On a bad
 * digit the bytes decoded so far have already been written to 'to'.
 */
int password_read(char * to, char * from)
  {
    size_t len = strlen(from);
    size_t pos;

    if((len != RSBAC_UM_PASS_LEN * 2) && (len != RSBAC_UM_PWDATA_LEN * 2))
      {
        fprintf(stderr, "Wrong encrypted password length!\n");
        return -RSBAC_EINVALIDVALUE;
      }
    for(pos = 0; pos + 1 < len; pos += 2)
      {
        int hi = password_hex_digit(from[pos]);
        int lo = password_hex_digit(from[pos + 1]);

        /* Reject anything that is not a hex digit instead of silently
         * decoding it as 0, which would store a wrong password hash. */
        if(hi < 0 || lo < 0)
          return -RSBAC_EINVALIDVALUE;
        to[pos / 2] = (char) ((hi << 4) | lo);
      }
    return (int) (len / 2);
  }

/*
 * Report a failed modification on stderr.
 *
 * res:  result code of the modification; only negative values are reported.
 * item: label of the item that was being modified, printed before the
 *       error name.
 */
void mod_show_error(int res, char * item)
  {
    char errname[80];

    /* Zero and positive results are not errors: nothing to print. */
    if(res >= 0)
      return;

    fprintf(stderr, "%s: %s\n", item, get_error_name(errname, res));
  }

/*
 * Entry point: parse the command line flags, then apply the requested
 * password, name and time-to-live changes to every group named on the
 * command line.
 *
 * Flags may be combined in one argument (e.g. -vP); a flag that takes a
 * value consumes the next command line argument.
 *
 * Returns 0 after -h, and 1 if no group argument is left after the flags
 * or a group argument is neither a known name nor a number. Several error
 * paths call exit(1) directly; the normal path ends with exit(0) once all
 * groups were handled, even if single modifications reported an error.
 */
int main(int argc, char ** argv)
{
  int res = 0;
  int verbose = 0;
  int err;
  union rsbac_um_mod_data_t data;
  int do_pass = 0;
  char * pass = NULL;
  char * crypt_pass = NULL;
  rsbac_boolean_t oldcryptpass = TRUE;
  char * name = NULL;
  int do_ttl = 0;
  rsbac_time_t ttl = 0;
  rsbac_list_ta_number_t ta_number = 0;
  u_int stopflags = FALSE;
  int i;

  locale_init();

  progname = argv[0];
  {
    /* Default transaction number comes from the environment; -N overrides it. */
    char * env = getenv("RSBAC_TA");

    if(env)
      ta_number = strtoul(env,0,0);
  }
  while((argc > 1) && (argv[1][0] == '-') && !stopflags)
    {
      char * pos = argv[1];
      pos++;
      while(*pos)
        {
          switch(*pos)
            {
              case '-':
                stopflags = TRUE;
                break;
              case 'h':
                use();
                return 0;
              case 'v':
                verbose++;
                break;
              case 'p':
                if(argc > 2)
                  {
                    pass=argv[2];
                    do_pass = 1;
                    argc--;
                    argv++;
                  }
                else
                  fprintf(stderr, gettext("%s: missing argument for parameter %c\n"), progname, *pos);
                break;
              case 'A':
                if(argc > 2)
                  {
                    strncpy(hash_algo, argv[2], RSBAC_UM_ALGO_NAME_LEN);
                    /* strncpy does not terminate on truncation */
                    hash_algo[RSBAC_UM_ALGO_NAME_LEN - 1] = 0;
                    argc--;
                    argv++;
                  }
                else
                  fprintf(stderr, gettext("%s: missing argument for parameter %c\n"), progname, *pos);
                break;
              case 'P':
                 /* do_pass without any password value: the password is
                  * passed to the kernel as NULL below */
                 pass = NULL;
                 do_pass = 1;
                 break;
              case 'Q':
                if(argc > 2)
                  {
                    err = password_read(password, argv[2]);
                    error_exit(err);
                    /* The decoded length tells which of the two formats was given. */
                    if (err == RSBAC_UM_PASS_LEN)
                      oldcryptpass = TRUE;
                    else if (err == RSBAC_UM_PWDATA_LEN)
                      oldcryptpass = FALSE;
                    else
                      error_exit(-RSBAC_EINVALIDVALUE);
                    crypt_pass = password;
                    do_pass = 1;
                    argc--;
                    argv++;
                  }
                else
                  fprintf(stderr, gettext("%s: missing argument for parameter %c\n"), progname, *pos);
                break;
              case 'g':
                if(argc > 2)
                  {
                    name=argv[2];
                    argc--;
                    argv++;
                  }
                else
                  fprintf(stderr, gettext("%s: missing argument for parameter %c\n"), progname, *pos);
                break;
              case 't':
                if(argc > 2)
                  {
                    res = strtokmgu32(argv[2], &ttl);
                    error_exit(res);
                    do_ttl = 1;
                    argc--;
                    argv++;
                  }
                else
                  fprintf(stderr, gettext("%s: missing ttl value for parameter %c\n"), progname, *pos);
                break;
              case 'D':
                if(argc > 2)
                  {
                    /* days -> seconds */
                    ttl = 86400 * strtoul(argv[2], 0, 10);
                    do_ttl = 1;
                    argc--;
                    argv++;
                  }
                else
                  fprintf(stderr, gettext("%s: missing ttl value for parameter %c\n"), progname, *pos);
                break;
              case 'T':
                if(argc > 2)
                  {
                    /* Absolute time is converted to a ttl relative to now. */
                    rsbac_time_t now = time(NULL);
                    ttl = strtoul(argv[2], 0, 10);
                    if(ttl > now)
                      {
                        ttl -= now;
                        do_ttl = 1;
                        argc--;
                        argv++;
                      }
                    else
                      {
                        fprintf(stderr,
                                gettext("%s: ttl value for parameter %c is in the past, exiting\n"), progname, *pos);
                        exit(1);
                      }
                  }
                else
                  fprintf(stderr, gettext("%s: missing ttl value for parameter %c\n"), progname, *pos);
                break;
              case 'N':
                if(argc > 2)
                  {
                    ta_number = strtoul(argv[2], 0, 10);
                    argc--;
                    argv++;
                  }
                else
                  {
                    fprintf(stderr, gettext("%s: missing transaction number value for parameter %c\n"), progname, *pos);
                    exit(1);
                  }
                break;
              case 'S':
                if(argc > 2)
                  {
                    /* Only vset is recorded here; it is combined with each
                     * group id in the loop below, once the id is known. */
                    if (rsbac_get_vset_num(argv[2], &vset))
                      {
                        fprintf(stderr, gettext("%s: invalid virtual set number for parameter %c\n"), progname, *pos);
                        exit(1);
                      }
                    argc--;
                    argv++;
                  }
                else
                  {
                    fprintf(stderr, gettext("%s: missing virtual set number for parameter %c\n"), progname, *pos);
                    exit(1);
                  }
                break;

              default:
                fprintf(stderr, gettext("%s: unknown parameter %c\n"), progname, *pos);
                exit(1);
            }
          pos++;
        }
      argv++;
      argc--;
    }

  if (argc > 1)
    {
      for(i=1; i< argc; i++)
        {
          rsbac_gid_t group = RSBAC_GEN_GID(vset, RSBAC_NO_GROUP);
          if(rsbac_um_get_gid(ta_number, argv[i], &group))
            {
              /* Not a known group name: try "[set/]number" instead. */
              char * tmp_name = argv[i];
              char * p = tmp_name;
              rsbac_um_set_t tmp_vset = vset;

              while (*p && (*p != '/'))
                p++;
              if (*p) {
                        /* Temporarily cut the string at '/' to parse the set part. */
                        *p = 0;
                        if (rsbac_get_vset_num(tmp_name, &tmp_vset))
                          {
                            fprintf(stderr, gettext("%s: invalid virtual set number %s, skipping\n"), progname, tmp_name);
                            continue;
                          }
                        *p = '/';
                        p++;
                        tmp_name = p;
              }
              group = strtoul(tmp_name, NULL, 0);
              /* strtoul yields 0 for non-numeric input, so only a literal "0" counts as group 0 */
              if(!group && strcmp(tmp_name,"0"))
                {
                  fprintf(stderr, gettext("%s: Unknown group %s\n"), progname, argv[i]);
                  return 1;
                }
              group = RSBAC_GEN_GID(tmp_vset, group);
            }
          /* NOTE(review): this appears to replace the set part of the gid
           * with the global -S set, so a set given as "set/" prefix in the
           * argument is validated above but not used - confirm intended. */
          group = RSBAC_GEN_GID(vset, RSBAC_GID_NUM(group));
          /* Fetching the name serves as existence check for the group. */
          res = rsbac_um_get_group_item(ta_number, group, UM_name, &data);
          if(res)
            {
              fprintf(stderr, gettext("%s: Unknown group %s\n"), progname, argv[i]);
              exit(1);
            }
          if(do_pass)
            {
              /* The source buffers pass / crypt_pass are left intact here,
               * because they are needed again for the next group argument;
               * they get wiped after the loop. */
              if(crypt_pass)
                {
                  if (oldcryptpass) {
                    memcpy(data.string, crypt_pass, RSBAC_UM_PASS_LEN);
                    res = rsbac_um_mod_group(ta_number, group, UM_cryptpass, &data);
                  } else {
                    memcpy(data.string, crypt_pass, RSBAC_UM_PWDATA_LEN);
                    res = rsbac_um_mod_group(ta_number, group, UM_cryptpass_algo, &data);
                  }
                  memset(&data, 0, sizeof(data));
                }
              else
              if(pass)
                {
                  if (hash_algo[0] == 0) {
                    strncpy(data.string, pass, RSBAC_MAXNAMELEN);
                    data.string[RSBAC_MAXNAMELEN - 1] = 0;
                    res = rsbac_um_mod_group(ta_number, group, UM_pass, &data);
                  } else {
                    /* Layout: algorithm name in the first RSBAC_UM_ALGO_NAME_LEN
                     * bytes, plaintext password right behind it. */
                    strncpy(data.string, hash_algo, RSBAC_UM_ALGO_NAME_LEN);
                    data.string[RSBAC_UM_ALGO_NAME_LEN - 1] = 0;
                    strncpy(data.string + RSBAC_UM_ALGO_NAME_LEN, pass, RSBAC_MAXNAMELEN - RSBAC_UM_ALGO_NAME_LEN);
                    data.string[RSBAC_MAXNAMELEN - 1] = 0;
                    res = rsbac_um_mod_group(ta_number, group, UM_pass_algo, &data);
                  }
                  memset(&data, 0, sizeof(data));
                }
              else
                res = rsbac_um_mod_group(ta_number, group, UM_pass, NULL);
              mod_show_error(res, "Password");
            }
          if(name)
            {
              strncpy(data.string, name, RSBAC_MAXNAMELEN);
              data.string[RSBAC_MAXNAMELEN - 1] = 0;
              res = rsbac_um_mod_group(ta_number, group, UM_name, &data);
              mod_show_error(res, "Groupname");
            }
          if(do_ttl)
            {
              data.ttl = ttl;
              res = rsbac_um_mod_group(ta_number, group, UM_ttl, &data);
              mod_show_error(res, "TTL");
            }
        }
      /* Wipe the secrets only now: doing it inside the loop would hand an
       * empty or all-zero password to every group after the first one. */
      if(pass)
        memset(pass, 0, strlen(pass));
      if(crypt_pass)
        memset(password, 0, sizeof(password));
      exit(0);
    }
  else
    {
      use();
      return 1;
    }
  return (res);
}

