{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Enter bit_length: 4\n",
      "193 191\n"
     ]
    }
   ],
   "source": [
    "import RSA_module\n",
    "\n",
    "bit_length = int(input(\"Enter bit_length: \"))\n",
    "\n",
    "public, private = RSA_module.generate_keypair(2**bit_length)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RSA Encryption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Write message: 7enTropy7\n",
      "\n",
      "Encrypted message: 10284142620619325088624438111132138110284\n"
     ]
    }
   ],
   "source": [
    "msg = input(\"\\nWrite message: \")\n",
    "encrypted_msg, encryption_obj = RSA_module.encrypt(msg, public)\n",
    "\n",
    "print(\"\\nEncrypted message: \" + encrypted_msg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RSA Decryption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Decrypted message using RSA Algorithm: 7enTropy7\n"
     ]
    }
   ],
   "source": [
    "decrypted_msg = RSA_module.decrypt(encryption_obj, private)\n",
    "\n",
    "print(\"\\nDecrypted message using RSA Algorithm: \" + decrypted_msg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Shor's Quantum Algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import gcd,log\n",
    "from random import randint\n",
    "import numpy as np\n",
    "from qiskit import *\n",
    "\n",
    "qasm_sim = qiskit.Aer.get_backend('qasm_simulator')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def period(a,N):\n",
    "    \n",
    "    available_qubits = 16 \n",
    "    r=-1   \n",
    "    \n",
    "    if N >= 2**available_qubits:\n",
    "        print(str(N)+' is too big for IBMQX')\n",
    "    \n",
    "    qr = QuantumRegister(available_qubits)   \n",
    "    cr = ClassicalRegister(available_qubits)\n",
    "    qc = QuantumCircuit(qr,cr)\n",
    "    x0 = randint(1, N-1) \n",
    "    x_binary = np.zeros(available_qubits, dtype=bool)\n",
    "    \n",
    "    for i in range(1, available_qubits + 1):\n",
    "        bit_state = (N%(2**i)!=0)\n",
    "        if bit_state:\n",
    "            N -= 2**(i-1)\n",
    "        x_binary[available_qubits-i] = bit_state    \n",
    "    \n",
    "    for i in range(0,available_qubits):\n",
    "        if x_binary[available_qubits-i-1]:\n",
    "            qc.x(qr[i])\n",
    "    x = x0\n",
    "    \n",
    "    while np.logical_or(x != x0, r <= 0):\n",
    "        r+=1\n",
    "        qc.measure(qr, cr) \n",
    "        for i in range(0,3): \n",
    "            qc.x(qr[i])\n",
    "        qc.cx(qr[2],qr[1])\n",
    "        qc.cx(qr[1],qr[2])\n",
    "        qc.cx(qr[2],qr[1])\n",
    "        qc.cx(qr[1],qr[0])\n",
    "        qc.cx(qr[0],qr[1])\n",
    "        qc.cx(qr[1],qr[0])\n",
    "        qc.cx(qr[3],qr[0])\n",
    "        qc.cx(qr[0],qr[1])\n",
    "        qc.cx(qr[1],qr[0])\n",
    "        \n",
    "        result = execute(qc,backend = qasm_sim, shots=1024).result()\n",
    "        counts = result.get_counts()\n",
    "        \n",
    "        #print(qc)\n",
    "        \n",
    "        results = [[],[]]\n",
    "        for key,value in counts.items(): \n",
    "            results[0].append(key)\n",
    "            results[1].append(int(value))\n",
    "        s = results[0][np.argmax(np.array(results[1]))]\n",
    "    return r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def shors_breaker(N):\n",
    "    N = int(N)\n",
    "    while True:\n",
    "        a=randint(0,N-1)\n",
    "        g=gcd(a,N)\n",
    "        if g!=1 or N==1:\n",
    "            return g,N//g\n",
    "        else:\n",
    "            r=period(a,N) \n",
    "            if r % 2 != 0:\n",
    "                continue\n",
    "            elif pow(a,r//2,N)==-1:\n",
    "                continue\n",
    "            else:\n",
    "                p=gcd(pow(a,r//2)+1,N)\n",
    "                q=gcd(pow(a,r//2)-1,N)\n",
    "                if p==N or q==N:\n",
    "                    continue\n",
    "                return p,q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def modular_inverse(a,m): \n",
    "    a = a % m; \n",
    "    for x in range(1, m) : \n",
    "        if ((a * x) % m == 1) : \n",
    "            return x \n",
    "    return 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_shor = public[1]\n",
    "assert N_shor>0,\"Input must be positive\"\n",
    "p,q = shors_breaker(N_shor)\n",
    "phi = (p-1) * (q-1)  \n",
    "d_shor = modular_inverse(public[0], phi) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cracking RSA using Shor's Algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Message Cracked using Shors Algorithm: 7enTropy7\n",
      "\n"
     ]
    }
   ],
   "source": [
    "decrypted_msg = RSA_module.decrypt(encryption_obj, (d_shor,N_shor))\n",
    "\n",
    "print('\\nMessage Cracked using Shors Algorithm: ' + decrypted_msg + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}