# ECDH椭圆双曲线（比DH快10倍的密钥交换）算法简介和封装

前面有几篇blog就提到我有计划支持使用ECDH密钥交换。近期也是抽空把以前的DH密钥交换跨平台适配从[atgateway](https://github.com/atframework/atsf4g-co/tree/master/atframework/service/atgateway)抽离出来，而后接入了[ECDH](https://en.wikipedia.org/wiki/Elliptic-curve_Diffie%E2%80%93Hellman)流程。

## 背景

对[DH](https://en.wikipedia.org/wiki/Diffie%E2%80%93Hellman_key_exchange)和[ECDH](https://en.wikipedia.org/wiki/Elliptic-curve_Diffie%E2%80%93Hellman)算法的具体原理这里不做具体介绍了，可以点击链接看。[DH](https://en.wikipedia.org/wiki/Diffie%E2%80%93Hellman_key_exchange)和[ECDH](https://en.wikipedia.org/wiki/Elliptic-curve_Diffie%E2%80%93Hellman)的主要的作用就是在通信双方发送一些公有参数，保留私有参数，而后通过一系列计算双方都能够得到一个一致的结果。而这个运算的逆运算复杂度过高，在有限时间内不可解（至少量子计算机问世以前不可解），以保证密钥安全性。除了维基百科外，我还看到篇文章图画的很好看的：<http://andrea.corbellini.name/2015/05/30/elliptic-curve-cryptography-ecdh-and-ecdsa/> 。而[DH](https://en.wikipedia.org/wiki/Diffie%E2%80%93Hellman_key_exchange)和[ECDH](https://en.wikipedia.org/wiki/Elliptic-curve_Diffie%E2%80%93Hellman)得区别简单来说就是，前者使用了一个大素数和两个随机数，而后者使用了[ECC](https://en.wikipedia.org/wiki/Elliptic-curve_cryptography)算法和两个随机点。

实际应用中，有些加密算法的密钥碰撞计算难度反而比破解[DH](https://en.wikipedia.org/wiki/Diffie%E2%80%93Hellman_key_exchange)和[ECDH](https://en.wikipedia.org/wiki/Elliptic-curve_Diffie%E2%80%93Hellman)要容易（比如[atgateway](https://github.com/atframework/atsf4g-co/tree/master/atframework/service/atgateway)支持的[XXTEA](https://en.wikipedia.org/wiki/XXTEA)算法，这个算法很简单所以也非常高效）。所以有些工程实践中会每隔一段时间再走一次密钥交换流程来更换密钥。

## ECDH和DH

使用[ECDH](https://en.wikipedia.org/wiki/Elliptic-curve_Diffie%E2%80%93Hellman)做密钥交换得时候你可能也会看到[ECDHE](https://en.wikipedia.org/wiki/Elliptic-curve_Diffie%E2%80%93Hellman)这个词，这个多出来的E的意思是指每次公钥都随机生成。因为像HTTPS里那种是可以从证书文件里取静态公钥的。我封装的接口其实也是每次都随机生成公钥。

### 适配和接入

我这里选择了按[RFC 4492](https://tools.ietf.org/html/rfc4492)进行接入，因为我主要适配的两个库（[openssl](https://www.openssl.org/)和[mbedtls](https://tls.mbed.org/)）共同支持的似乎只有[RFC 4492](https://tools.ietf.org/html/rfc4492)。[openssl](https://www.openssl.org/)自不用说了，[mbedtls](https://tls.mbed.org/)主要是为了如果需要用到Android或者iOS上的话，用[mbedtls](https://tls.mbed.org/)比较容易一些。接入的过程中主要有三个问题。

第一个是[openssl](https://www.openssl.org/)的1.0.1版本支持的算法较少，而高版本比较多一些，而且[openssl](https://www.openssl.org/)本身是可以裁剪算法的。这个用宏来判定就行了，比较easy。

第二个是[openssl](https://www.openssl.org/)和[mbedtls](https://tls.mbed.org/)都有自己内部的ID和名称，并且不一样，传参数都是用内部ID的，然后输出再做转换。解决方法就是手夯了一个映射表，和之前搞\[crypto\_cipher]\[10]得方法一样。

第三就是[openssl](https://www.openssl.org/)的1.0.2和1.1.0的结构大不一样了，1.1.0版本的接口也更加严谨。这个比较麻烦，找了好久才找到1.1.0的这两部分代码实现。现在的做法是按1.1.0的方法封装了一些接口给1.0.1/1.0.2使用。比如这种:

```cpp
/**
 * @see crypto/dh/dh_lib.c in openssl 1.1.x
 */
static inline void DH_get0_key(const DH *dh, const BIGNUM **pub_key, const BIGNUM **priv_key) {
    if (pub_key != NULL) *pub_key = dh->pub_key;
    if (priv_key != NULL) *priv_key = dh->priv_key;
}
```

比较好的消息是[DH](https://en.wikipedia.org/wiki/Diffie%E2%80%93Hellman_key_exchange)和[ECDH](https://en.wikipedia.org/wiki/Elliptic-curve_Diffie%E2%80%93Hellman)的流程基本一致，只是传输内容不同。但是[openssl](https://www.openssl.org/)不像[mbedtls](https://tls.mbed.org/)，没有良好的接口封装，里面密钥交换的细节实现得到[openssl](https://www.openssl.org/)的源码里去找。并且流程比较长，而且[openssl](https://www.openssl.org/)的实现不太好，有很多冗余的拷贝操作。实际上我接入的时候是对它的代码流程有些许优化话，主要还是减少不必要的拷贝。[openssl](https://www.openssl.org/)1.0.2的[ECDH](https://en.wikipedia.org/wiki/Elliptic-curve_Diffie%E2%80%93Hellman)和[ECDSA](https://en.wikipedia.org/wiki/Elliptic-curve_cryptography)流程代码在ssl/s3\_srvr.c和ssl/s3\_clnt.c里，而1.1.0版本的相关流程代码在ssl/statem/statem\_clnt.c和ssl/statem/statem\_srvr.c里。版本适配代码大多在ssl/ssl\_locl.h和ssl/t1\_lib.c里。

唉openssl的文档和代码找起来真是蛋疼，比如[ECDHE](https://en.wikipedia.org/wiki/Elliptic-curve_Diffie%E2%80%93Hellman)的服务器下发系数的[openssl](https://www.openssl.org/)1.0.2的代码简化一下就是这样（不简化的话太长了）：

```cpp
if (((group = EC_KEY_get0_group(ecdh)) == NULL) ||
    (EC_KEY_get0_public_key(ecdh) == NULL) ||
    (EC_KEY_get0_private_key(ecdh) == NULL)) {
    SSLerr(SSL_F_SSL3_SEND_SERVER_KEY_EXCHANGE, ERR_R_ECDH_LIB);
    goto err;
}

if ((curve_id =
        tls1_ec_nid2curve_id(EC_GROUP_get_curve_name(group)))
    == 0) {
    SSLerr(SSL_F_SSL3_SEND_SERVER_KEY_EXCHANGE,
            SSL_R_UNSUPPORTED_ELLIPTIC_CURVE);
    goto err;
}
encodedlen = EC_POINT_point2oct(group,
    EC_KEY_get0_public_key(ecdh),
    POINT_CONVERSION_UNCOMPRESSED,
    NULL, 0, NULL);

encodedPoint = (unsigned char *)
    OPENSSL_malloc(encodedlen * sizeof(unsigned char));
bn_ctx = BN_CTX_new();
if ((encodedPoint == NULL) || (bn_ctx == NULL)) {
    SSLerr(SSL_F_SSL3_SEND_SERVER_KEY_EXCHANGE,
            ERR_R_MALLOC_FAILURE);
    goto err;
}

encodedlen = EC_POINT_point2oct(group,
    EC_KEY_get0_public_key(ecdh),
    POINT_CONVERSION_UNCOMPRESSED,
    encodedPoint, encodedlen, bn_ctx);

if (encodedlen == 0) {
    SSLerr(SSL_F_SSL3_SEND_SERVER_KEY_EXCHANGE, ERR_R_ECDH_LIB);
    goto err;
}

BN_CTX_free(bn_ctx);
bn_ctx = NULL;

// ... 然后是dump的代码 ..., p是输出缓冲区的起始指针
*p = NAMED_CURVE_TYPE;
p += 1;
*p = 0;
p += 1;
*p = curve_id;
p += 1;
*p = encodedlen;
p += 1;
memcpy((unsigned char *)p,
        (unsigned char *)encodedPoint, encodedlen);
OPENSSL_free(encodedPoint);
encodedPoint = NULL;
p += encodedlen;
```

相比之下[mbedtls](https://tls.mbed.org/)的简洁多了，都不用裁剪:

```cpp
int mbedtls_ecdh_make_params( mbedtls_ecdh_context *ctx, size_t *olen,
                      unsigned char *buf, size_t blen,
                      int (*f_rng)(void *, unsigned char *, size_t),
                      void *p_rng )
{
    int ret;
    size_t grp_len, pt_len;

    if( ctx == NULL || ctx->grp.pbits == 0 )
        return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );

    if( ( ret = mbedtls_ecdh_gen_public( &ctx->grp, &ctx->d, &ctx->Q, f_rng, p_rng ) )
                != 0 )
        return( ret );

    if( ( ret = mbedtls_ecp_tls_write_group( &ctx->grp, &grp_len, buf, blen ) )
                != 0 )
        return( ret );

    buf += grp_len;
    blen -= grp_len;

    if( ( ret = mbedtls_ecp_tls_write_point( &ctx->grp, &ctx->Q, ctx->point_format,
                                     &pt_len, buf, blen ) ) != 0 )
        return( ret );

    *olen = grp_len + pt_len;
    return( 0 );
}

int mbedtls_ecp_tls_write_group( const mbedtls_ecp_group *grp, size_t *olen,
                         unsigned char *buf, size_t blen )
{
    const mbedtls_ecp_curve_info *curve_info;

    if( ( curve_info = mbedtls_ecp_curve_info_from_grp_id( grp->id ) ) == NULL )
        return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );

    /*
     * We are going to write 3 bytes (see below)
     */
    *olen = 3;
    if( blen < *olen )
        return( MBEDTLS_ERR_ECP_BUFFER_TOO_SMALL );

    /*
     * First byte is curve_type, always named_curve
     */
    *buf++ = MBEDTLS_ECP_TLS_NAMED_CURVE;

    /*
     * Next two bytes are the namedcurve value
     */
    buf[0] = curve_info->tls_id >> 8;
    buf[1] = curve_info->tls_id & 0xFF;

    return( 0 );
}

int mbedtls_ecp_tls_write_point( const mbedtls_ecp_group *grp, const mbedtls_ecp_point *pt,
                         int format, size_t *olen,
                         unsigned char *buf, size_t blen )
{
    int ret;

    /*
     * buffer length must be at least one, for our length byte
     */
    if( blen < 1 )
        return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );

    if( ( ret = mbedtls_ecp_point_write_binary( grp, pt, format,
                    olen, buf + 1, blen - 1) ) != 0 )
        return( ret );

    /*
     * write length to the first byte and update total length
     */
    buf[0] = (unsigned char) *olen;
    ++*olen;

    return( 0 );
}
```

另外我只接入了密钥交换的流程，像它们其实有更高级的SSL/TLS接口，还包含验证流程、加解密流程、握手的cookie等等。我希望提供的是一个个单独可拆开用的组件，所以这里只接入了密钥交换。像加解密就封装到了\[crypto\_cipher]\[10]里。而像[atgateway](https://github.com/atframework/atsf4g-co/tree/master/atframework/service/atgateway)有自己的验证流程，并不像标准TLS/SSL那样走Hash。

### 交互流程差异

| 步骤      | DH                                    | ECDH                                                              |
| ------- | ------------------------------------- | ----------------------------------------------------------------- |
| 初始化     | 加载DH参数（主要是一个大素数P和系数G），由DH参数决定密钥长度     | 加载双曲线（[RFC 4492](https://tools.ietf.org/html/rfc4492)），由双曲线决定密钥长度 |
| 服务器下发系数 | 下发DH算法的P（大素数）、G、GY(G^Y mod P)，保留私有数据Y | 下发双曲线算法group和公钥点Q，保留私钥点d                                          |
| 服务器下发内容 | 2字节P长度，P,  2字节G长度，G,2字节GY长度，GY        | 1字节类型(3)，2字节双曲线ID，1字节Q长度,Q                                        |
| 客户端读取系数 | 读入P、G,记录GY为远端公钥                       | 读入双曲线算法group，记录Q为远端公钥                                             |
| 客户端创建公钥 | 随机出X，计算并上传GX(G^X mod P)，保留私有数据X       | 生成和上传公钥点Qp，保留私钥点z                                                 |
| 客户端上传内容 | 2字节GX长度，GX                            | 1字节Qp长度,Qp                                                        |
| 客户端计算密钥 | 根据P、G、GY、X计算出密钥                       | 根据group、Q、z计算出密钥                                                  |
| 服务器计算密钥 | 根据P、G、GX、Y计算出密钥                       | 根据group、Qp、d计算出密钥                                                 |

## API

封装了接口以后，现在的接口就非常简单了，和[mbedtls](https://tls.mbed.org/)的流程很像。比如下面DH算法的（截取自单元测试）

```cpp
// 客户端共享配置
util::crypto::dh cli_dh;

// 服务器共享配置(保存DH系数等)
util::crypto::dh svr_dh;

// 服务器 - init: 读取DHParam的PEM文件
{
    util::crypto::dh::shared_context::ptr_t svr_shctx = util::crypto::dh::shared_context::create();

    std::string dir;
    CASE_EXPECT_TRUE(util::file_system::dirname(__FILE__, 0, dir, 2));
    dir += util::file_system::DIRECTORY_SEPARATOR;
    dir += "resource";
    dir += util::file_system::DIRECTORY_SEPARATOR;
    dir += "test-dhparam.pem";
    // 前面只是找到PEM文件路径
    CASE_EXPECT_EQ(0, svr_shctx->init(dir.c_str()));
    CASE_EXPECT_EQ(0, svr_dh.init(svr_shctx));
}

// 客户端 - init: 设置成DH模式
{
    util::crypto::dh::shared_context::ptr_t cli_shctx = util::crypto::dh::shared_context::create();
    CASE_EXPECT_EQ(0, cli_shctx->init(util::crypto::dh::method_t::EN_CDT_DH));
    CASE_EXPECT_EQ(0, cli_dh.init(cli_shctx));
}

std::vector<unsigned char> switch_params; // 服务器下发的数据
std::vector<unsigned char> switch_public; // 客户端上传的数据
std::vector<unsigned char> cli_secret; // 保存客户端计算的密钥
std::vector<unsigned char> svr_secret; // 保存服务器计算的密钥

// step 1 - 服务器: 计算密钥对，输出DH参数和公钥
CASE_EXPECT_EQ(0, svr_dh.make_params(switch_params));

// step 2 - 客户端: 读取服务器的公钥和DH参数
CASE_EXPECT_EQ(0, cli_dh.read_params(switch_params.data(), switch_params.size()));

// step 3 - 客户端: 计算密钥对，输出公钥
CASE_EXPECT_EQ(0, cli_dh.make_public(switch_public));

// step 4 - 客户端: 计算协商结果
CASE_EXPECT_EQ(0, cli_dh.calc_secret(cli_secret));

// step 5 - 服务器: 读取客户端的公钥
CASE_EXPECT_EQ(0, svr_dh.read_public(switch_public.data(), switch_public.size()));

// step 6 - 服务器: 计算协商结果
CASE_EXPECT_EQ(0, svr_dh.calc_secret(svr_secret));

// DH流程结束，后面是检查两边结果一致
CASE_EXPECT_EQ(cli_secret.size(), svr_secret.size());
if (cli_secret.size() == svr_secret.size()) {
    CASE_EXPECT_EQ(0, memcmp(cli_secret.data(), svr_secret.data(), svr_secret.size()));
}
```

还有ECDH算法的（截取自单元测试）

```cpp
// 枚举所有的加密算法
const std::vector<std::string> &all_curves = util::crypto::dh::get_all_curve_names();

for (size_t curve_idx = 0; curve_idx < all_curves.size(); ++curve_idx) {

    // 客户端共享配置
    util::crypto::dh cli_dh;

    // 服务器共享配置(保存椭圆算法ID等)
    util::crypto::dh svr_dh;

    // 服务器 - init: 读取指定的椭圆曲线
    {
        util::crypto::dh::shared_context::ptr_t svr_shctx = util::crypto::dh::shared_context::create();
        CASE_EXPECT_EQ(0, svr_shctx->init(all_curves[curve_idx].c_str()));
        CASE_EXPECT_EQ(0, svr_dh.init(svr_shctx));
    }

    // 客户端 - init: 设置为ECDH模式
    {
        util::crypto::dh::shared_context::ptr_t cli_shctx = util::crypto::dh::shared_context::create();
        CASE_EXPECT_EQ(0, cli_shctx->init(util::crypto::dh::method_t::EN_CDT_ECDH));
        CASE_EXPECT_EQ(0, cli_dh.init(cli_shctx));
    }

    std::vector<unsigned char> switch_params;  // 服务器下发的数据
    std::vector<unsigned char> switch_public;  // 客户端上传的数据
    std::vector<unsigned char> cli_secret;  // 保存客户端计算的密钥
    std::vector<unsigned char> svr_secret;  // 保存服务器计算的密钥

    // step 1 - 服务器: 计算密钥对，输出双曲线ID和公钥点
    CASE_EXPECT_EQ(0, svr_dh.make_params(switch_params));

    // step 2 - 客户端: 读取服务器的双曲线ID和公钥点
    CASE_EXPECT_EQ(0, cli_dh.read_params(switch_params.data(), switch_params.size()));

    // step 3 - 客户端: 计算密钥对，输出公钥点
    CASE_EXPECT_EQ(0, cli_dh.make_public(switch_public));

    // step 4 - 客户端: 计算协商结果
    CASE_EXPECT_EQ(0, cli_dh.calc_secret(cli_secret));

    // step 5 - 服务器: 读取客户端的公钥点
    CASE_EXPECT_EQ(0, svr_dh.read_public(switch_public.data(), switch_public.size()));

    // step 6 - 服务器: 计算协商结果
    CASE_EXPECT_EQ(0, svr_dh.calc_secret(svr_secret));

    // ECDH流程结束，后面是检查两边结果一致
    CASE_EXPECT_EQ(cli_secret.size(), svr_secret.size());
    if (cli_secret.size() == svr_secret.size()) {
        CASE_EXPECT_EQ(0, memcmp(cli_secret.data(), svr_secret.data(), svr_secret.size()));
    }
}
```

## 单元测试、性能和valgrind报告

抽离出来以后就比较方便加单元测试了。单元测试的时候发现openssl底层会分配一些全局数据，导致valgrind报still reachable。但是实测了多次不同次数的加解密后的报告的块数和总大小都一样。所以这个就可以忽略了。

然后就是报告的结果，我直接从CI里复制出来了一部分。Windows和Linux里的结果一致：

```
[ RUN      ] crypto_dh.dh
[ RUNNING  ] Test DH algorithm 32 times, key len 128 bits. 
[       OK ] crypto_dh.dh (228.725 ms)
[ RUN      ] crypto_dh.ecdh
[ RUNNING  ] Test ECDH algorithm 16 times for 8 curves done. 
[ RUNNING  ]   Fastest => ecdh:secp224r1 cost 0.60075ms(avg.) key len 224 bits. 
[ RUNNING  ]   Slowest => ecdh:secp384r1 cost 3.86687ms(avg.) key len 384 bits. 
[       OK ] crypto_dh.ecdh (221.049 ms)
```

这就是标题里说的**快10倍**的来源。可以看到，耗时最短的双曲线，密钥长度是224，平均每次耗时是0.6ms(client+server)。而DH用的是1024bits的DHParameter，密钥长度128，平均耗时是228.725/32=7.14ms 。而最慢的双曲线性能也是两倍多，而平均值是221.049/16/18=1.73ms，性能也是4倍多。

封装接口的时候，其实我是选取了[mbedtls](https://tls.mbed.org/)目前支持的全部算法。其实openssl本身支持的更多，并且可以裁剪算法。最后适配完之后，看了下单元测试，[openssl](https://www.openssl.org/)1.0.1最多只支持其中8种算法，而1.0.2以上版本是支持全部16种的。交叉测试也pass了（用[openssl](https://www.openssl.org/)版本做服务端，[mbedtls](https://tls.mbed.org/)做客户端），但是不太好写构建工程所以交叉测试是手动进行的（用[atgateway](https://github.com/atframework/atsf4g-co/tree/master/atframework/service/atgateway)编的服务器，然后里面的sample换另一个库编）。

## 最后

现在所有代码都放在 <https://github.com/atframework/atframe_utils/blob/master/include/algorithm/crypto_cipher.h> 和 <https://github.com/atframework/atframe_utils/blob/master/src/algorithm/crypto_dh.cpp> 封装完以后，万一其他什么模块要用，就方便很多了。

其实对于Android，官方推荐的做法是从Java接口封因此鞥过来，而iOS也有自带的加解密库，适配这两个可能是能够最大幅度减小包大小的方法。但是我感觉裁剪良好的情况下，用mbedtls也没大多少。所以暂时还没有去研究这两个平台独立的接入方式。

我也不是专门搞密码学的，所以理解上可能还会有些偏差，欢迎大家拍砖指正交流。

\[10]: <https://github.com/atframework/atframe_utils/blob/master/include/algorithm/crypto_cipher.h>
