198 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			C
		
	
	
	
	
	
			
		
		
	
	
			198 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			C
		
	
	
	
	
	
| #include "hssl.h"
 | |
| 
 | |
| #ifdef WITH_GNUTLS
 | |
| 
 | |
| #include "gnutls/gnutls.h"
 | |
| 
 | |
/* Report the name of the TLS backend this translation unit implements. */
const char* hssl_backend() {
    static const char backend_name[] = "gnutls";
    return backend_name;
}
 | |
| 
 | |
| typedef gnutls_certificate_credentials_t gnutls_ctx_t;
 | |
| 
 | |
| hssl_ctx_t hssl_ctx_new(hssl_ctx_opt_t* param) {
 | |
|     static int s_initialized = 0;
 | |
|     if (s_initialized == 0) {
 | |
|         gnutls_global_init();
 | |
|         s_initialized = 1;
 | |
|     }
 | |
| 
 | |
|     gnutls_ctx_t ctx;
 | |
|     const char* crt_file = NULL;
 | |
|     const char* key_file = NULL;
 | |
|     const char* ca_file = NULL;
 | |
|     const char* ca_path = NULL;
 | |
| 
 | |
|     int ret = gnutls_certificate_allocate_credentials(&ctx);
 | |
|     if (ret != GNUTLS_E_SUCCESS) {
 | |
|         return NULL;
 | |
|     }
 | |
| 
 | |
|     if (param) {
 | |
|         if (param->crt_file && *param->crt_file) {
 | |
|             crt_file = param->crt_file;
 | |
|         }
 | |
|         if (param->key_file && *param->key_file) {
 | |
|             key_file = param->key_file;
 | |
|         }
 | |
|         if (param->ca_file && *param->ca_file) {
 | |
|             ca_file = param->ca_file;
 | |
|         }
 | |
|         if (param->ca_path && *param->ca_path) {
 | |
|             ca_path = param->ca_path;
 | |
|         }
 | |
| 
 | |
|         if (ca_file) {
 | |
|             ret = gnutls_certificate_set_x509_trust_file(ctx, ca_file, GNUTLS_X509_FMT_PEM);
 | |
|             if (ret < 0) {
 | |
|                 fprintf(stderr, "ssl ca_file failed!\n");
 | |
|                 goto error;
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         if (ca_path) {
 | |
|             ret = gnutls_certificate_set_x509_trust_dir(ctx, ca_path, GNUTLS_X509_FMT_PEM);
 | |
|             if (ret < 0) {
 | |
|                 fprintf(stderr, "ssl ca_path failed!\n");
 | |
|                 goto error;
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         if (crt_file && key_file) {
 | |
|             ret = gnutls_certificate_set_x509_key_file(ctx, crt_file, key_file, GNUTLS_X509_FMT_PEM);
 | |
|             if (ret != GNUTLS_E_SUCCESS) {
 | |
|                 fprintf(stderr, "ssl crt_file/key_file error!\n");
 | |
|                 goto error;
 | |
|             }
 | |
|         }
 | |
| 
 | |
|         if (param->verify_peer && !ca_file && !ca_path) {
 | |
|             gnutls_certificate_set_x509_system_trust(ctx);
 | |
|         }
 | |
|     }
 | |
|     return ctx;
 | |
| error:
 | |
|     gnutls_certificate_free_credentials(ctx);
 | |
|     return NULL;
 | |
| }
 | |
| 
 | |
| void hssl_ctx_free(hssl_ctx_t ssl_ctx) {
 | |
|     if (!ssl_ctx) return;
 | |
|     gnutls_ctx_t ctx = (gnutls_ctx_t)ssl_ctx;
 | |
|     gnutls_certificate_free_credentials(ctx);
 | |
| }
 | |
| 
 | |
| typedef struct gnutls_s {
 | |
|     gnutls_session_t session;
 | |
|     gnutls_ctx_t     ctx;
 | |
|     int              fd;
 | |
| } gnutls_t;
 | |
| 
 | |
| hssl_t hssl_new(hssl_ctx_t ssl_ctx, int fd) {
 | |
|     gnutls_t* gnutls = (gnutls_t*)malloc(sizeof(gnutls_t));
 | |
|     if (gnutls == NULL) return NULL;
 | |
|     gnutls->session = NULL;
 | |
|     gnutls->ctx = (gnutls_ctx_t)ssl_ctx;
 | |
|     gnutls->fd = fd;
 | |
|     return (hssl_t)gnutls;
 | |
| }
 | |
| 
 | |
| static int hssl_init(hssl_t ssl, int endpoint) {
 | |
|     if (ssl == NULL) return HSSL_ERROR;
 | |
|     gnutls_t* gnutls = (gnutls_t*)ssl;
 | |
|     if (gnutls->session == NULL) {
 | |
|         gnutls_init(&gnutls->session, endpoint);
 | |
|         gnutls_priority_set_direct(gnutls->session, "NORMAL", NULL);
 | |
|         gnutls_credentials_set(gnutls->session, GNUTLS_CRD_CERTIFICATE, gnutls->ctx);
 | |
|         gnutls_transport_set_ptr(gnutls->session, (gnutls_transport_ptr_t)(ptrdiff_t)gnutls->fd);
 | |
|     }
 | |
|     return HSSL_OK;
 | |
| }
 | |
| 
 | |
| void hssl_free(hssl_t ssl) {
 | |
|     if (ssl == NULL) return;
 | |
|     gnutls_t* gnutls = (gnutls_t*)ssl;
 | |
|     if (gnutls->session) {
 | |
|         gnutls_deinit(gnutls->session);
 | |
|         gnutls->session = NULL;
 | |
|     }
 | |
|     free(gnutls);
 | |
| }
 | |
| 
 | |
| static int hssl_handshake(hssl_t ssl) {
 | |
|     if (ssl == NULL) return HSSL_ERROR;
 | |
|     gnutls_t* gnutls = (gnutls_t*)ssl;
 | |
|     if (gnutls->session == NULL) return HSSL_ERROR;
 | |
|     int ret = 0;
 | |
|     while (1) {
 | |
|         ret = gnutls_handshake(gnutls->session);
 | |
|         if (ret == GNUTLS_E_SUCCESS) {
 | |
|             return HSSL_OK;
 | |
|         }
 | |
|         else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) {
 | |
|             return gnutls_record_get_direction(gnutls->session) == 0 ? HSSL_WANT_READ : HSSL_WANT_WRITE;
 | |
|         }
 | |
|         else if (gnutls_error_is_fatal(ret)) {
 | |
|             // fprintf(stderr, "gnutls_handshake failed: %s\n", gnutls_strerror(ret));
 | |
|             return HSSL_ERROR;
 | |
|         }
 | |
|     }
 | |
|     return HSSL_OK;
 | |
| }
 | |
| 
 | |
| int hssl_accept(hssl_t ssl) {
 | |
|     if (ssl == NULL) return HSSL_ERROR;
 | |
|     gnutls_t* gnutls = (gnutls_t*)ssl;
 | |
|     if (gnutls->session == NULL) {
 | |
|         hssl_init(ssl, GNUTLS_SERVER);
 | |
|     }
 | |
|     return hssl_handshake(ssl);
 | |
| }
 | |
| 
 | |
| int hssl_connect(hssl_t ssl) {
 | |
|     if (ssl == NULL) return HSSL_ERROR;
 | |
|     gnutls_t* gnutls = (gnutls_t*)ssl;
 | |
|     if (gnutls->session == NULL) {
 | |
|         hssl_init(ssl, GNUTLS_CLIENT);
 | |
|     }
 | |
|     return hssl_handshake(ssl);
 | |
| }
 | |
| 
 | |
| int hssl_read(hssl_t ssl, void* buf, int len) {
 | |
|     if (ssl == NULL) return HSSL_ERROR;
 | |
|     gnutls_t* gnutls = (gnutls_t*)ssl;
 | |
|     if (gnutls->session == NULL) return HSSL_ERROR;
 | |
|     int ret = 0;
 | |
|     while ((ret = gnutls_record_recv(gnutls->session, buf, len)) == GNUTLS_E_INTERRUPTED);
 | |
|     return ret;
 | |
| }
 | |
| 
 | |
| int hssl_write(hssl_t ssl, const void* buf, int len) {
 | |
|     if (ssl == NULL) return HSSL_ERROR;
 | |
|     gnutls_t* gnutls = (gnutls_t*)ssl;
 | |
|     if (gnutls->session == NULL) return HSSL_ERROR;
 | |
|     int ret = 0;
 | |
|     while ((ret = gnutls_record_send(gnutls->session, buf, len)) == GNUTLS_E_INTERRUPTED);
 | |
|     return ret;
 | |
| }
 | |
| 
 | |
| int hssl_close(hssl_t ssl) {
 | |
|     if (ssl == NULL) return HSSL_ERROR;
 | |
|     gnutls_t* gnutls = (gnutls_t*)ssl;
 | |
|     if (gnutls->session == NULL) return HSSL_ERROR;
 | |
|     gnutls_bye(gnutls->session, GNUTLS_SHUT_RDWR);
 | |
|     return HSSL_OK;
 | |
| }
 | |
| 
 | |
| int hssl_set_sni_hostname(hssl_t ssl, const char* hostname) {
 | |
|     if (ssl == NULL) return HSSL_ERROR;
 | |
|     gnutls_t* gnutls = (gnutls_t*)ssl;
 | |
|     if (gnutls->session == NULL) {
 | |
|         hssl_init(ssl, GNUTLS_CLIENT);
 | |
|     }
 | |
|     gnutls_server_name_set(gnutls->session, GNUTLS_NAME_DNS, hostname, strlen(hostname));
 | |
|     return 0;
 | |
| }
 | |
| 
 | |
| #endif // WITH_GNUTLS
 | 
