195 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			C
		
	
	
	
	
	
			
		
		
	
	
			195 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			C
		
	
	
	
	
	
| #include "hssl.h"
 | |
| #include "hsocket.h"
 | |
| 
 | |
| #ifdef WITH_MBEDTLS
 | |
| 
 | |
| #include "mbedtls/version.h"
 | |
| #include "mbedtls/entropy.h"
 | |
| #include "mbedtls/ctr_drbg.h"
 | |
| #include "mbedtls/x509.h"
 | |
| #include "mbedtls/ssl.h"
 | |
| #include "mbedtls/net_sockets.h"
 | |
| #include "mbedtls/error.h"
 | |
| #include "mbedtls/debug.h"
 | |
| 
 | |
| #if defined(MBEDTLS_SSL_CACHE_C)
 | |
| #include "mbedtls/ssl_cache.h"
 | |
| #endif
 | |
| 
 | |
| #ifdef _MSC_VER
 | |
| //#pragma comment(lib, "libmbedtls.a")
 | |
| //#pragma comment(lib, "libmbedx509.a")
 | |
| //#pragma comment(lib, "libmbedcrypto.a")
 | |
| #endif
 | |
| 
 | |
// Report which TLS implementation this translation unit was built against.
// The returned string has static storage duration; callers must not free it.
const char* hssl_backend() {
    static const char backend_name[] = "mbedtls";
    return backend_name;
}
 | |
| 
 | |
| struct mbedtls_ctx {
 | |
|     mbedtls_entropy_context     entropy;
 | |
|     mbedtls_ctr_drbg_context    ctr_drbg;
 | |
|     mbedtls_ssl_config          conf;
 | |
|     mbedtls_x509_crt            cert;
 | |
|     mbedtls_pk_context          pkey;
 | |
| #if defined(MBEDTLS_SSL_CACHE_C)
 | |
|     mbedtls_ssl_cache_context   cache;
 | |
| #endif
 | |
| };
 | |
| 
 | |
| hssl_ctx_t hssl_ctx_new(hssl_ctx_opt_t* param) {
 | |
|     struct mbedtls_ctx* ctx = (struct mbedtls_ctx*)malloc(sizeof(struct mbedtls_ctx));
 | |
|     if (ctx == NULL) return NULL;
 | |
| 
 | |
|     mbedtls_ssl_config_init(&ctx->conf);
 | |
| #if defined(MBEDTLS_SSL_CACHE_C)
 | |
|     mbedtls_ssl_cache_init(&ctx->cache);
 | |
| #endif
 | |
|     mbedtls_x509_crt_init(&ctx->cert);
 | |
|     mbedtls_pk_init(&ctx->pkey);
 | |
|     mbedtls_entropy_init(&ctx->entropy);
 | |
|     mbedtls_ctr_drbg_init(&ctx->ctr_drbg);
 | |
| 
 | |
|     int mode = MBEDTLS_SSL_VERIFY_NONE;
 | |
|     int endpoint = MBEDTLS_SSL_IS_CLIENT;
 | |
|     bool check = false;
 | |
|     if (param) {
 | |
|         if (param->crt_file && *param->crt_file) {
 | |
|             if (mbedtls_x509_crt_parse_file(&ctx->cert, param->crt_file) != 0) {
 | |
|                 fprintf(stderr, "ssl crt_file error!\n");
 | |
|                 goto error;
 | |
|             }
 | |
|         }
 | |
|         if (param->key_file && *param->key_file) {
 | |
| #if MBEDTLS_VERSION_MAJOR >= 3
 | |
|             if (mbedtls_pk_parse_keyfile(&ctx->pkey, param->key_file, NULL, NULL, NULL) != 0) {
 | |
| #else
 | |
|             if (mbedtls_pk_parse_keyfile(&ctx->pkey, param->key_file, NULL) != 0) {
 | |
| #endif
 | |
|                 fprintf(stderr, "ssl key_file error!\n");
 | |
|                 goto error;
 | |
|             }
 | |
|             check = true;
 | |
|         }
 | |
|         if (param->verify_peer) {
 | |
|             mode = MBEDTLS_SSL_VERIFY_REQUIRED;
 | |
|         }
 | |
|         if (param->endpoint == HSSL_SERVER) {
 | |
|             endpoint = MBEDTLS_SSL_IS_SERVER;
 | |
|         }
 | |
|     }
 | |
|     mbedtls_ctr_drbg_seed(&ctx->ctr_drbg, mbedtls_entropy_func, &ctx->entropy, NULL, 0);
 | |
|     if (mbedtls_ssl_config_defaults(&ctx->conf, endpoint,
 | |
|         MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT) != 0) {
 | |
|         fprintf(stderr, "ssl config error!\n");
 | |
|         goto error;
 | |
|     }
 | |
|     mbedtls_ssl_conf_authmode(&ctx->conf, mode);
 | |
|     mbedtls_ssl_conf_rng(&ctx->conf, mbedtls_ctr_drbg_random, &ctx->ctr_drbg);
 | |
| 
 | |
| #if defined(MBEDTLS_SSL_CACHE_C)
 | |
|     mbedtls_ssl_conf_session_cache(&ctx->conf, &ctx->cache, mbedtls_ssl_cache_get, mbedtls_ssl_cache_set);
 | |
| #endif
 | |
| 
 | |
|     if (check) {
 | |
|         mbedtls_ssl_conf_ca_chain(&ctx->conf, ctx->cert.next, NULL);
 | |
|         if (mbedtls_ssl_conf_own_cert(&ctx->conf, &ctx->cert, &ctx->pkey) != 0) {
 | |
|             fprintf(stderr, "ssl key_file check failed!\n");
 | |
|             goto error;
 | |
|         }
 | |
|     }
 | |
|     return ctx;
 | |
| error:
 | |
|     free(ctx);
 | |
|     return NULL;
 | |
| }
 | |
| 
 | |
| void hssl_ctx_free(hssl_ctx_t ssl_ctx) {
 | |
|     if (!ssl_ctx) return;
 | |
|     struct mbedtls_ctx *mctx = (struct mbedtls_ctx *)ssl_ctx;
 | |
|     mbedtls_x509_crt_free(&mctx->cert);
 | |
|     mbedtls_pk_free(&mctx->pkey);
 | |
|     mbedtls_ssl_config_free(&mctx->conf);
 | |
| #if defined(MBEDTLS_SSL_CACHE_C)
 | |
|     mbedtls_ssl_cache_free(&mctx->cache);
 | |
| #endif
 | |
|     mbedtls_ctr_drbg_free(&mctx->ctr_drbg);
 | |
|     mbedtls_entropy_free(&mctx->entropy);
 | |
|     free(mctx);
 | |
| }
 | |
| 
 | |
| static int __mbedtls_net_send(void *ctx, const unsigned char *buf, size_t len) {
 | |
|     int fd = (intptr_t)ctx;
 | |
|     // int n = write(fd, buf, len);
 | |
|     int n = send(fd, (char*)(buf), (int)(len), 0);
 | |
|     if (n >= 0) return n;
 | |
|     return ((socket_errno() == EAGAIN || socket_errno() == EINPROGRESS) ? MBEDTLS_ERR_SSL_WANT_WRITE : -1);
 | |
| }
 | |
| 
 | |
| static int __mbedtls_net_recv(void *ctx, unsigned char *buf, size_t len) {
 | |
|     int fd = (intptr_t)ctx;
 | |
|     // int n = read(fd, buf, len);
 | |
|     int n = recv(fd, (char*)(buf), (int)(len), 0);
 | |
|     if (n >= 0) return n;
 | |
|     return ((socket_errno() == EAGAIN || socket_errno() == EINPROGRESS) ? MBEDTLS_ERR_SSL_WANT_READ : -1);
 | |
| }
 | |
| 
 | |
| hssl_t hssl_new(hssl_ctx_t ssl_ctx, int fd) {
 | |
|     struct mbedtls_ctx* mctx = (struct mbedtls_ctx*)ssl_ctx;
 | |
|     mbedtls_ssl_context* ssl = (mbedtls_ssl_context*)malloc(sizeof(mbedtls_ssl_context));
 | |
|     if (ssl == NULL) return NULL;
 | |
|     mbedtls_ssl_init(ssl);
 | |
|     mbedtls_ssl_setup(ssl, &mctx->conf);
 | |
|     mbedtls_ssl_set_bio(ssl, (void*)(intptr_t)fd, __mbedtls_net_send, __mbedtls_net_recv, NULL);
 | |
|     return ssl;
 | |
| }
 | |
| 
 | |
| void hssl_free(hssl_t ssl) {
 | |
|     if (ssl) {
 | |
|         mbedtls_ssl_free(ssl);
 | |
|         ssl = NULL;
 | |
|     }
 | |
| }
 | |
| 
 | |
| static int hssl_handshake(hssl_t ssl) {
 | |
|     int ret = mbedtls_ssl_handshake(ssl);
 | |
|     if (ret != 0) {
 | |
|         if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
 | |
|             return HSSL_WANT_READ;
 | |
|         }
 | |
|         else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
 | |
|             return HSSL_WANT_WRITE;
 | |
|         }
 | |
|     }
 | |
|     return ret;
 | |
| }
 | |
| 
 | |
| int hssl_accept(hssl_t ssl) {
 | |
|     return hssl_handshake(ssl);
 | |
| }
 | |
| 
 | |
| int hssl_connect(hssl_t ssl) {
 | |
|     return hssl_handshake(ssl);
 | |
| }
 | |
| 
 | |
| int hssl_read(hssl_t ssl, void* buf, int len) {
 | |
|     return mbedtls_ssl_read(ssl, buf, len);
 | |
| }
 | |
| 
 | |
| int hssl_write(hssl_t ssl, const void* buf, int len) {
 | |
|     return mbedtls_ssl_write(ssl, buf, len);
 | |
| }
 | |
| 
 | |
| int hssl_close(hssl_t ssl) {
 | |
|     return 0;
 | |
| }
 | |
| 
 | |
| int hssl_set_sni_hostname(hssl_t ssl, const char* hostname) {
 | |
| #ifdef MBEDTLS_X509_CRT_PARSE_C
 | |
|     mbedtls_ssl_set_hostname(ssl, hostname);
 | |
| #endif
 | |
|     return 0;
 | |
| }
 | |
| 
 | |
| #endif // WITH_MBEDTLS
 | 
