# Pure-Python Salsa20 stream cipher plus a small file-encryption CLI.
#
# Contents (as far as this copy shows):
#   _rROL32(a, b)        32-bit rotate-left of a by b bits.
#   Salsa20(key, iv=8 zero bytes, rounds=20)
#       key: bytes of length 16 or 32; iv: bytes of length 8;
#       rounds: 8, 12 or 20.  Raises Exception on any other value.
#       iv_setup(iv)     loads a new 8-byte iv and resets the block counter
#                        (per its docstring, an iv must never be reused with
#                        the same key; it need not be secret).
#       encrypt(data) / decrypt(data)
#                        XOR data with the keystream (decrypt is the same
#                        function).  Successive calls continue the stream,
#                        so every call except the last must pass a multiple
#                        of 64 bytes, otherwise the next call raises.
#   check_sum(x)         SHA-256 digest of x (_SHA_LEN bytes).
#   generate_iv()        8-byte iv: b'S20c' signature + a packed timestamp
#                        (per the __main__ comment: 0.2 s ticks since unix
#                        time 1330000000).
#   decrypt_file(key, iname, oname) -> (bytes_written, encryption_time)
#       Input layout read here: 8-byte iv, then an encrypted
#       (SHA-256 digest + first _HEADLEN plaintext bytes) block, then the
#       encrypted remainder.  Raises Exception for equal file names, an
#       illegal key, a short file, a missing 'S20c' signature, or a digest
#       mismatch ('Wrong password/corrupted data').
#   __main__             in-memory round-trip self-test, then the CLI:
#                        salsa20.py e|d key input-file output-file
#
# NOTE(review): this copy of the file is damaged and will not run as is:
#   1. Statement newlines and indentation have been collapsed, so many
#      statements share one physical line.
#   2. Inside generate_iv, everything between `struct.pack("` and
#      `_HEADLEN:` is missing.  The gap presumably ran from a `<` format
#      character (e.g. "<I") to a `len(data) >` comparison and was stripped
#      as if it were an HTML tag.  It evidently contained the rest of
#      generate_iv, analyze_iv, the _KEY_LEN / _SHA_LEN / _HEADLEN constants
#      and the start of encrypt_file, all of which are referenced below.
#      Restore that span from the upstream copy.  It cannot be reconstructed
#      reliably from here: _HEADLEN in particular fixes the on-disk format.
#
# NOTE(review): issues to address once the file is restored:
#   - Salsa20.encrypt(b'') still consumes one keystream block and sets
#     lastchunk to 0, so every later encrypt/decrypt call raises.
#   - The block counter state[8] is incremented without masking or a carry
#     into state[9].  Every use of it inside the rounds is masked to 32
#     bits, so for one key/iv the keystream repeats after 2**32 blocks
#     (~274.9 GB) instead of continuing with a 64-bit counter.
#   - encrypt XORs one byte per Python iteration.  XOR-ing whole 64-byte
#     blocks (e.g. via int.from_bytes / int.to_bytes) would be much faster.
#   - The CLI uses the passphrase itself as key material (a length byte
#     plus the passphrase, repeated), with no key-derivation function.
#   - The iv is time-based.  Two files encrypted with the same key inside
#     one timestamp tick would share an iv and therefore a keystream.
#   - decrypt_file reads the whole remainder of the input into memory.
#   - The __main__ self-test uses assert (skipped under python -O).  The CLI
#     prints exceptions but still exits with status 0.
"""See http://www.seanet.com/~bugbee/crypto/salsa20/salsa20.py for comments etc""" import struct def _rROL32(a,b): return ((a << b) | (a >> (32 - b))) & 0xffffffff class Salsa20(object): """s20 = Salsa20(key, iv, 8/12/20); ciphertext = s20.encrypt(message)""" def __init__( self, key, iv=b'\0\0\0\0\0\0\0\0', rounds=20 ): if type(key)!=bytes or type(iv)!=bytes: raise Exception('key and iv must be bytes') if len(key) not in (16,32): raise Exception('key must be either 16 or 32 bytes') if len(iv) != 8: raise Exception('iv must be 8 bytes') if rounds not in (20, 12, 8): raise Exception('number of rounds must be 8, 12, or 20') self._key_setup(key); self.iv_setup(iv); self.ROUNDS = rounds def _key_setup( self, key ): """key is converted to a list of 4-byte unsigned integers (32 bits); call iv_setup afterwards""" ks = [0x61707865, 0,0,0,0, 0x3320646e, 0,0,0,0, 0x79622d32, 0,0,0,0, 0x6b206574] if len(key) == 16: ks[5],ks[10], key = 0x3120646e,0x79622d36, key+key k = struct.unpack('<8I', key) ks[1:5] = k[:4]; ks[11:15] = k[4:] self.key_state = ks def iv_setup( self, iv ): """self.state and other working strucures are lists of 4-byte unsigned integers (32 bits). The iv should never be reused with the same key value, but it is not a secret. E.g. use time, etc. Prepending 8 bytes of iv to the ciphertext is the usual way to pass it to decoder.""" if len(iv) != 8: raise Exception('iv must be 8 bytes') iv_state = self.key_state[:] iv_state[6],iv_state[7] = struct.unpack('<2I', iv) iv_state[8],iv_state[9] = 0,0 # 8 will be counter, 9 will stay zero self.state = iv_state self.lastchunk = 64 # flag to ensure all but the last; chunks are multiple of 64 bytes def _salsa20_scramble(self): # 64 bytes in """ self.state and other working strucures are lists of 4-byte unsigned integers (32 bits). 
output must be converted to bytestring before return.""" x = self.state[:] # make a copy and work with it _ROL32=_rROL32 for i in range(self.ROUNDS//2): x[ 4] ^= _ROL32( (x[ 0]+x[12]) & 0xffffffff, 7) x[ 8] ^= _ROL32( (x[ 4]+x[ 0]) & 0xffffffff, 9) x[12] ^= _ROL32( (x[ 8]+x[ 4]) & 0xffffffff, 13) x[ 0] ^= _ROL32( (x[12]+x[ 8]) & 0xffffffff, 18) x[ 9] ^= _ROL32( (x[ 5]+x[ 1]) & 0xffffffff, 7) x[13] ^= _ROL32( (x[ 9]+x[ 5]) & 0xffffffff, 9) x[ 1] ^= _ROL32( (x[13]+x[ 9]) & 0xffffffff, 13) x[ 5] ^= _ROL32( (x[ 1]+x[13]) & 0xffffffff, 18) x[14] ^= _ROL32( (x[10]+x[ 6]) & 0xffffffff, 7) x[ 2] ^= _ROL32( (x[14]+x[10]) & 0xffffffff, 9) x[ 6] ^= _ROL32( (x[ 2]+x[14]) & 0xffffffff, 13) x[10] ^= _ROL32( (x[ 6]+x[ 2]) & 0xffffffff, 18) x[ 3] ^= _ROL32( (x[15]+x[11]) & 0xffffffff, 7) x[ 7] ^= _ROL32( (x[ 3]+x[15]) & 0xffffffff, 9) x[11] ^= _ROL32( (x[ 7]+x[ 3]) & 0xffffffff, 13) x[15] ^= _ROL32( (x[11]+x[ 7]) & 0xffffffff, 18) x[ 1] ^= _ROL32( (x[ 0]+x[ 3]) & 0xffffffff, 7) x[ 2] ^= _ROL32( (x[ 1]+x[ 0]) & 0xffffffff, 9) x[ 3] ^= _ROL32( (x[ 2]+x[ 1]) & 0xffffffff, 13) x[ 0] ^= _ROL32( (x[ 3]+x[ 2]) & 0xffffffff, 18) x[ 6] ^= _ROL32( (x[ 5]+x[ 4]) & 0xffffffff, 7) x[ 7] ^= _ROL32( (x[ 6]+x[ 5]) & 0xffffffff, 9) x[ 4] ^= _ROL32( (x[ 7]+x[ 6]) & 0xffffffff, 13) x[ 5] ^= _ROL32( (x[ 4]+x[ 7]) & 0xffffffff, 18) x[11] ^= _ROL32( (x[10]+x[ 9]) & 0xffffffff, 7) x[ 8] ^= _ROL32( (x[11]+x[10]) & 0xffffffff, 9) x[ 9] ^= _ROL32( (x[ 8]+x[11]) & 0xffffffff, 13) x[10] ^= _ROL32( (x[ 9]+x[ 8]) & 0xffffffff, 18) x[12] ^= _ROL32( (x[15]+x[14]) & 0xffffffff, 7) x[13] ^= _ROL32( (x[12]+x[15]) & 0xffffffff, 9) x[14] ^= _ROL32( (x[13]+x[12]) & 0xffffffff, 13) x[15] ^= _ROL32( (x[14]+x[13]) & 0xffffffff, 18) for i in range(16): x[i] = (x[i] + self.state[i]) & 0xffffffff self.state[8] += 1 # we don't do carry to state[9] b/c it's only after 274.8779 GB output = struct.pack('<16I',*x) return output # 64 bytes out def encrypt( self, datain ): """ datain and dataout are bytestrings. 
submited in chunks of 64 bytes; only last may be shorter""" if self.lastchunk != 64: raise Exception('size of last chunk not a multiple of 64 bytes') n_left = len(datain) dataout = bytearray(n_left) # we can set individual members of bytearray (i.e. it's mutable) stream = self._salsa20_scramble() for i in range(min(n_left,64)): dataout[i] = datain[i] ^ stream[i] #dataout[p:q] = bytes(stream[i]^d for i,d in enumerate(datain[p:q])) # xor n_done = 64 while n_left > 64: n_left -= 64 stream = self._salsa20_scramble() for i in range(min(n_left,64)): dataout[n_done+i] = datain[n_done+i] ^ stream[i] n_done += 64 self.lastchunk = n_left return bytes(dataout) decrypt = encrypt # same fn for decrypting import time, hashlib def check_sum(x): return hashlib.sha256(x).digest() # _SHA_LEN bytes def generate_iv(): return b'S20c'+struct.pack("_HEADLEN: o.write( s20.encrypt(data[_HEADLEN:]) ) return len(data) def decrypt_file( key, iname, oname ): if iname==oname: raise Exception('File names must be different') if type(key)!=bytes or len(key) not in (16,_KEY_LEN): raise Exception('Illegal key') with open( iname, "rb" ) as f: iv = f.read(8) if len(iv)<8: raise Exception('Too short file') ok,tm = analyze_iv(iv) if not ok: raise Exception('No signature') ciph = f.read(_SHA_LEN+_HEADLEN) if len(ciph)<_SHA_LEN: raise Exception('Too short file') tail = f.read() # till the end s20 = Salsa20(key[:_KEY_LEN],iv) orig = s20.decrypt(ciph) sha,head = orig[:_SHA_LEN],orig[_SHA_LEN:] if check_sum( head ) != sha: raise Exception('Wrong password/corrupted data') with open( oname, "wb" ) as o: o.write( head ) if len(tail)>0: o.write( s20.decrypt(tail) ) return len(head)+len(tail),tm if __name__ == '__main__': iv = generate_iv() # 8 bytes: signature + time in 0.2 sec ticks since unix time 1330000000 s20 = Salsa20(b'123456789A123456789B123456789C12',iv) ciph = s20.encrypt(b'test') s20.iv_setup(iv) data = s20.decrypt(ciph) assert data == b'test' import sys if len(sys.argv)!=5 or sys.argv[1] not in 
('d','e'): print( "Syntax: salsa20.py e|d key input-file output-file" ) sys.exit(0) key = (chr(len(sys.argv[2])) + sys.argv[2]).encode('utf-8',errors='backslashreplace') while len(key) <= _KEY_LEN: key += key try: t0 = time.time() if sys.argv[1]=='e': n = encrypt_file( key[:_KEY_LEN], sys.argv[3], sys.argv[4] ) else: # 'd' n,t = decrypt_file( key[:_KEY_LEN], sys.argv[3], sys.argv[4] ) print( time.strftime( "OK, was encrypted: %Y.%m.%d %H:%M:%S", time.localtime( t ) ) ) tt = max( time.time() - t0, 0.01 ) print("%.1f secs, %.1f kB/s" % (tt, n/1e3/tt)) # 120-130 kB/s # i3: 235 files, 188 memory except Exception as e: print( e )