SM3密码杂凑算法是中国国家密码管理局2010年公布的中国商用密码杂凑算法标准。具体算法标准原始文本参见参考文献[1]。该算法于2012年发布为密码行业标准(GM/T 0004-2012),2016年发布为国家密码杂凑算法标准(GB/T 32905-2016)。
SM3适用于商用密码应用中的数字签名和验证,是在[SHA-256]基础上改进实现的一种算法,其安全性和SHA-256相当。SM3和MD5的迭代过程类似,也采用Merkle-Damgard结构。消息分组长度为512位,摘要值长度为256位。
整个算法的执行过程可以概括成四个步骤:消息填充、消息扩展、迭代压缩、输出结果。
消息填充
SM3的消息扩展步骤是以512位的数据分组作为输入的。因此,我们需要在一开始就把数据长度填充至512位的倍数。数据填充规则和MD5一样,具体步骤如下:
1、先填充一个“1”,后面加上k个“0”。其中k是满足(n+1+k) mod 512 = 448的最小正整数。
2、追加64位的数据长度(bit为单位,大端序存放1。观察算法标准原文附录A运算示例可以推知。)
消息扩展
SM3的迭代压缩步骤没有直接使用数据分组进行运算,而是使用这个步骤产生的132个消息字。(一个消息字的长度为32位/4个字节/8个16j进制数字)概括来说,先将一个512位数据分组划分为16个消息字,并且作为生成的132个消息字的前16个。再用这16个消息字递推生成剩余的116个消息字。
迭代压缩
SM3的迭代过程和MD5类似,也是Merkle-Damgard结构。但和MD5不同的是,SM3使用消息扩展得到的消息字进行运算。这个迭代过程可以用这幅图表示:
初值IV被放在A、B、C、D、E、F、G、H八个32位变量中,其具体数值参见参考文献[1]。整个算法中最核心、也最复杂的地方就在于压缩函数。压缩函数将这八个变量进行64轮相同的计算,一轮的计算过程如下图所示:
图中不同的数据流向用不同颜色的箭头表示。
最后,再将计算完成的A、B、C、D、E、F、G、H和原来的A、B、C、D、E、F、G、H分别进行异或,就是压缩函数的输出。这个输出再作为下一次调用压缩函数时的初值。依次类推,直到用完最后一组132个消息字为止。
输出结果
将得到的A、B、C、D、E、F、G、H八个变量拼接输出,就是SM3算法的输出。
JAVA实现SM3加密与校验
- 导入Maven依赖
<!-- https://mvnrepository.com/artifact/org.bouncycastle/bcprov-jdk15on -->
<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcprov-jdk15on</artifactId>
<version>1.70</version>
</dependency>
- SM3Utils
private static final String ENCODING = "UTF-8";
/**
* 加密
*
* @param src 明文
* @param key 密钥
* @return
* @throws Exception
*/
public static String encrypt(String src, String key) throws Exception {
return ByteUtils.toHexString(getEncryptByKey(src, key));
}
/**
* SM3加密方式之: 根据自定义密钥进行加密,返回加密后长度为32位的16进制字符串
*
* @param src 源数据
* @param key 密钥
* @return 加密后长度为32位的16进制字符串
* @throws Exception
*/
public static byte[] getEncryptByKey(String src, String key) throws Exception {
byte[] srcByte = src.getBytes(ENCODING);
byte[] keyByte = key.getBytes(ENCODING);
KeyParameter keyParameter = new KeyParameter(keyByte);
SM3Digest sm3 = new SM3Digest();
HMac hMac = new HMac(sm3);
hMac.init(keyParameter);
hMac.update(srcByte, 0, srcByte.length);
byte[] result = new byte[hMac.getMacSize()];
hMac.doFinal(result, 0);
return result;
}
/**
* 利用源数据+密钥校验与密文是否一致
*
* @param src 源数据
* @param key 密钥
* @param sm3HexStr 密文
* @return
* @throws Exception
*/
public static boolean verify(String src, String key, String sm3HexStr) throws Exception {
byte[] sm3HashCode = ByteUtils.fromHexString(sm3HexStr);
byte[] newHashCode = getEncryptByKey(src, key);
return Arrays.equals(newHashCode, sm3HashCode);
}
/**
* SM3加密方式之:不提供密钥的方式 SM3加密,返回加密后长度为64位的16进制字符串
*
* @param src 明文
* @return 加密后长度为64位的16进制字符串
*/
public static String encrypt(String src) {
return ByteUtils.toHexString(getEncryptBySrcByte(src.getBytes()));
}
/**
* 返回长度为32位的加密后的byte数组
*
* @param srcByte
* @return
*/
public static byte[] getEncryptBySrcByte(byte[] srcByte) {
SM3Digest sm3 = new SM3Digest();
sm3.update(srcByte, 0, srcByte.length);
byte[] encryptByte = new byte[sm3.getDigestSize()];
sm3.doFinal(encryptByte, 0);
return encryptByte;
}
/**
* 校验源数据与加密数据是否一致
*
* @param src 源数据
* @param sm3HexStr 16进制的加密数据
* @return
* @throws Exception
*/
public static boolean verify(String src, String sm3HexStr) throws Exception {
byte[] sm3HashCode = ByteUtils.fromHexString(sm3HexStr);
byte[] newHashCode = getEncryptBySrcByte(src.getBytes(ENCODING));
return Arrays.equals(newHashCode, sm3HashCode);
}
public static void main(String[] args) throws Exception {
String srcStr = "今天天气很晴朗";
String key = "zjqzjq";
// ******************************自定义密钥加密及校验*****************************************
String hexStrByKey = SM3Utils.encrypt(srcStr, key);
System.out.println("带密钥加密后的密文:" + hexStrByKey);
System.out.println("明文(带密钥)与密文校验结果:" + SM3Utils.verify(srcStr, key, hexStrByKey));
// ******************************无密钥的加密及校验******************************************
String hexStrNoKey = SM3Utils.encrypt(srcStr);
System.out.println("不带密钥加密后的密文:" + hexStrNoKey);
System.out.println("明文(不带密钥)与密文校验结果:" + SM3Utils.verify(srcStr, hexStrNoKey));
}
注:SM3算法的实现
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.math.BigInteger;
import java.util.Arrays;
import ltd.snowland.utils.NumberTool;
import ltd.snowland.utils.StreamTool;
/**
* SM3杂凑算法实现
*
* @author Potato
*
*/
public class SM3 {
private static char[] hexDigits = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E',
'F' };
private static final String ivHexStr = "7380166f 4914b2b9 172442d7 da8a0600 a96f30bc 163138aa e38dee4d b0fb0e4e";
private static final BigInteger IV = new BigInteger(ivHexStr.replaceAll(" ", ""), 16);
private static final Integer Tj15 = Integer.valueOf("79cc4519", 16);
private static final Integer Tj63 = Integer.valueOf("7a879d8a", 16);
private static final byte[] FirstPadding = { (byte) 0x80 };
private static final byte[] ZeroPadding = { (byte) 0x00 };
private static int T(int j) {
if (j >= 0 && j <= 15) {
return Tj15.intValue();
} else if (j >= 16 && j <= 63) {
return Tj63.intValue();
} else {
throw new RuntimeException("data invalid");
}
}
private static Integer FF(Integer x, Integer y, Integer z, int j) {
if (j >= 0 && j <= 15) {
return Integer.valueOf(x.intValue() ^ y.intValue() ^ z.intValue());
} else if (j >= 16 && j <= 63) {
return Integer.valueOf(
(x.intValue() & y.intValue()) | (x.intValue() & z.intValue()) | (y.intValue() & z.intValue()));
} else {
throw new RuntimeException("data invalid");
}
}
private static Integer GG(Integer x, Integer y, Integer z, int j) {
if (j >= 0 && j <= 15) {
return Integer.valueOf(x.intValue() ^ y.intValue() ^ z.intValue());
} else if (j >= 16 && j <= 63) {
return Integer.valueOf((x.intValue() & y.intValue()) | (~x.intValue() & z.intValue()));
} else {
throw new RuntimeException("data invalid");
}
}
private static Integer P0(Integer x) {
return Integer
.valueOf(x.intValue() ^ Integer.rotateLeft(x.intValue(), 9) ^ Integer.rotateLeft(x.intValue(), 17));
}
private static Integer P1(Integer x) {
return Integer
.valueOf(x.intValue() ^ Integer.rotateLeft(x.intValue(), 15) ^ Integer.rotateLeft(x.intValue(), 23));
}
private static byte[] padding(byte[] source) throws IOException {
if (source.length >= 0x2000000000000000l) {
throw new RuntimeException("src data invalid.");
}
long l = source.length * 8;
long k = 448 - (l + 1) % 512;
if (k < 0) {
k = k + 512;
}
ByteArrayOutputStream baos = new ByteArrayOutputStream();
baos.write(source);
baos.write(FirstPadding);
long i = k - 7;
while (i > 0) {
baos.write(ZeroPadding);
i -= 8;
}
baos.write(long2bytes(l));
return baos.toByteArray();
}
private static byte[] long2bytes(long l) {
byte[] bytes = new byte[8];
for (int i = 0; i < 8; i++) {
bytes[i] = (byte) (l >>> ((7 - i) * 8));
}
return bytes;
}
public static byte[] hash(byte[] source) throws IOException {
byte[] m1 = padding(source);
int n = m1.length / (512 / 8);
byte[] b;
byte[] vi = IV.toByteArray();
byte[] vi1 = null;
for (int i = 0; i < n; i++) {
b = Arrays.copyOfRange(m1, i * 64, (i + 1) * 64);
vi1 = CF(vi, b);
vi = vi1;
}
return vi1;
}
public static byte[] hash(String source) throws Exception {
return hash(source.getBytes());
}
public static byte[] hash(File file) throws Exception {
if (file.exists()) {
InputStream inStream = new FileInputStream(file);
return hash(StreamTool.readInputStream2ByteArray(inStream));
} else {
throw new FileNotFoundException();
}
}
private static byte[] CF(byte[] vi, byte[] bi) throws IOException {
int a, b, c, d, e, f, g, h;
a = toInteger(vi, 0);
b = toInteger(vi, 1);
c = toInteger(vi, 2);
d = toInteger(vi, 3);
e = toInteger(vi, 4);
f = toInteger(vi, 5);
g = toInteger(vi, 6);
h = toInteger(vi, 7);
int[] w = new int[68];
int[] w1 = new int[64];
for (int i = 0; i < 16; i++) {
w[i] = toInteger(bi, i);
}
for (int j = 16; j < 68; j++) {
w[j] = P1(w[j - 16] ^ w[j - 9] ^ Integer.rotateLeft(w[j - 3], 15)) ^ Integer.rotateLeft(w[j - 13], 7)
^ w[j - 6];
}
for (int j = 0; j < 64; j++) {
w1[j] = w[j] ^ w[j + 4];
}
int ss1, ss2, tt1, tt2;
for (int j = 0; j < 64; j++) {
ss1 = Integer.rotateLeft(Integer.rotateLeft(a, 12) + e + Integer.rotateLeft(T(j), j), 7);
ss2 = ss1 ^ Integer.rotateLeft(a, 12);
tt1 = FF(a, b, c, j) + d + ss2 + w1[j];
tt2 = GG(e, f, g, j) + h + ss1 + w[j];
d = c;
c = Integer.rotateLeft(b, 9);
b = a;
a = tt1;
h = g;
g = Integer.rotateLeft(f, 19);
f = e;
e = P0(tt2);
}
byte[] v = toByteArray(a, b, c, d, e, f, g, h);
for (int i = 0; i < v.length; i++) {
v[i] = (byte) (v[i] ^ vi[i]);
}
return v;
}
private static int toInteger(byte[] source, int index) {
StringBuilder valueStr = new StringBuilder("");
for (int i = 0; i < 4; i++) {
valueStr.append(hexDigits[(byte) ((source[index * 4 + i] & 0xF0) >> 4)]);
valueStr.append(hexDigits[(byte) (source[index * 4 + i] & 0x0F)]);
}
return Long.valueOf(valueStr.toString(), 16).intValue();
}
private static byte[] toByteArray(int a, int b, int c, int d, int e, int f, int g, int h) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream(32);
baos.write(toByteArray(a));
baos.write(toByteArray(b));
baos.write(toByteArray(c));
baos.write(toByteArray(d));
baos.write(toByteArray(e));
baos.write(toByteArray(f));
baos.write(toByteArray(g));
baos.write(toByteArray(h));
return baos.toByteArray();
}
public static byte[] toByteArray(int i) {
byte[] byteArray = new byte[4];
byteArray[0] = (byte) (i >>> 24);
byteArray[1] = (byte) ((i & 0xFFFFFF) >>> 16);
byteArray[2] = (byte) ((i & 0xFFFF) >>> 8);
byteArray[3] = (byte) (i & 0xFF);
return byteArray;
}
private static String byteToHexString(byte b) {
int n = b;
if (n < 0)
n = 256 + n;
int d1 = n / 16;
int d2 = n % 16;
return "" + hexDigits[d1] + hexDigits[d2];
}
public static String byteArrayToHexString(byte[] b) {
StringBuffer resultSb = new StringBuffer();
for (int i = 0; i < b.length; i++) {
resultSb.append(byteToHexString(b[i]));
}
return resultSb.toString();
}
public static String hashHex(byte[] b) {
try {
return SM3.byteArrayToHexString(SM3.hash(b));
} catch (IOException e) {
e.printStackTrace();
}
return null;
}
public static String hashHex(String str) {
return hashHex(str.getBytes());
}
public static void main(String[] args) throws IOException {
System.out.println(SM3.byteArrayToHexString(SM3.hash("abc".getBytes())));
}
}