import { KEYUTIL, KJUR } from 'jsrsasign';
import isNil from 'lodash-es/isNil';
import { inject, singleton } from 'tsyringe';
import { InvalidJwtTokenError } from './errors/InvalidJwtTokenError';
import { InvalidUserPoolConfigError } from './errors/InvalidUserPoolConfigError';
import { Optional } from 'lib/types/Optional';
import {
  IAccessTokenPayload,
  IIdTokenPayload,
  JwtTokenHeader,
  OidcTokenType,
} from 'app/cross-cutting-concerns/authentication/interfaces/IAuthenticationTokens';
import type { IUserPoolConfig } from 'config/user-pool-config';

type JwtTokenPayload = IIdTokenPayload | IAccessTokenPayload;
interface DecodedToken {
  header: JwtTokenHeader;
  payload: JwtTokenPayload;
}

/**
 * Verifies OIDC tokens (ID and access tokens) against the JSON Web Keys held in the
 * injected user pool configuration for the user pool named by the environment.
 */
@singleton()
export class UserPoolUtils {
  constructor(@inject('UserPoolConfig') private readonly userPoolConfig: IUserPoolConfig) {}

  /**
   * Verifies a token's signature, algorithm and issuer, and checks that it is the expected
   * kind of token issued to this app's user pool client.
   *
   * The expected region, user pool id and client id are read from the
   * `REACT_APP_AWS_REGION`, `REACT_APP_AWS_USER_POOL_ID` and
   * `REACT_APP_AWS_USER_POOL_WEB_CLIENT_ID` environment variables on every call
   * (missing variables fall back to `''`).
   *
   * @param token - The raw compact-serialized JWT.
   * @param oidcTokenType - The expected token kind; compared against the `token_use` claim.
   * @returns The result of `KJUR.jws.JWS.verifyJWT`, restricted to `RS256` and the expected
   *   Cognito issuer. A failing signature/algorithm/issuer check yields `false`, whereas
   *   claim mismatches below are reported by throwing.
   * @throws InvalidJwtTokenError when the token is missing or cannot be parsed, its header
   *   `kid` matches no configured key, its `token_use` differs from `oidcTokenType`, or its
   *   `aud` (ID token) / `client_id` (access token) differs from the configured client id.
   * @throws InvalidUserPoolConfigError when no configuration exists for the user pool id.
   */
  public verifyOidcToken = (token: Optional<string>, oidcTokenType: OidcTokenType): boolean => {
    const awsRegion = process.env.REACT_APP_AWS_REGION ?? '';
    const awsUserPoolId = process.env.REACT_APP_AWS_USER_POOL_ID ?? '';
    const awsUserPoolClientId = process.env.REACT_APP_AWS_USER_POOL_WEB_CLIENT_ID ?? '';
    const issuer = `https://cognito-idp.${awsRegion}.amazonaws.com/${awsUserPoolId}`;
    const userPoolConfig = this.userPoolConfig[awsUserPoolId];

    if (isNil(token)) {
      throw new InvalidJwtTokenError();
    }

    // Parsed without verification: only used to pick the signing key and read claims.
    const decodedToken = this.decodeJwtToken(token);

    if (!userPoolConfig) {
      throw new InvalidUserPoolConfigError();
    }

    // Select the configured key whose id matches the one named in the token header.
    const jsonWebKey = userPoolConfig.keys.find(key => key.kid === decodedToken.header.kid);

    if (!jsonWebKey) {
      throw new InvalidJwtTokenError();
    }

    const publicKey = KEYUTIL.getKey(jsonWebKey);

    // Pinning `alg` to RS256 prevents the token header from choosing a weaker algorithm.
    const isTokenValid = KJUR.jws.JWS.verifyJWT(token, publicKey as jsrsasign.RSAKey, {
      alg: ['RS256'],
      iss: [issuer],
    });

    if (decodedToken.payload.token_use !== oidcTokenType) {
      throw new InvalidJwtTokenError();
    }

    // ID tokens carry the app client id in `aud`; access tokens carry it in `client_id`.
    if (oidcTokenType === OidcTokenType.ID && decodedToken.payload.aud !== awsUserPoolClientId) {
      throw new InvalidJwtTokenError();
    }

    if (oidcTokenType === OidcTokenType.ACCESS && decodedToken.payload.client_id !== awsUserPoolClientId) {
      throw new InvalidJwtTokenError();
    }

    return isTokenValid;
  };

  /**
   * Parses a JWT into its header and payload objects without verifying it.
   *
   * @throws InvalidJwtTokenError when the token cannot be parsed or lacks a JSON header/payload.
   */
  private decodeJwtToken = (token: string): DecodedToken => {
    const { headerObj: header, payloadObj: payload } = this.parseJws(token);

    if (!header || !payload) {
      throw new InvalidJwtTokenError();
    }

    return {
      header: header as unknown as JwtTokenHeader,
      payload: payload as unknown as JwtTokenPayload,
    };
  };

  /**
   * Runs jsrsasign's JWS parser, converting any exception it raises into InvalidJwtTokenError.
   *
   * The parser throws its own exceptions for malformed input. Examples are a wrong number of
   * `.`-separated segments (including the empty string) and undecodable base64url segments.
   * Normalising them keeps the "bad token -> InvalidJwtTokenError" contract for callers.
   */
  private parseJws = (token: string) => {
    try {
      return KJUR.jws.JWS.parse(token);
    } catch {
      throw new InvalidJwtTokenError();
    }
  };
}
