LCOV - code coverage report
Current view: top level - waltz/resolv - fd_res_msend.c (source / functions) Hit Total Coverage
Test: cov.lcov Lines: 0 232 0.0 %
Date: 2025-08-05 05:04:49 Functions: 0 5 0.0 %

          Line data    Source code
       1             : #define _GNU_SOURCE /* SYS_close */
       2             : #include <sys/socket.h>
       3             : #include <netinet/in.h>
       4             : #include <netinet/tcp.h>
       5             : #include <netdb.h>
       6             : #include <arpa/inet.h>
       7             : #include <stdint.h>
       8             : #include <string.h>
       9             : #include <poll.h>
      10             : #include <time.h>
      11             : #include <unistd.h>
      12             : #include <errno.h>
      13             : #include <pthread.h>
      14             : #include "syscall.h"
      15             : #include "fd_lookup.h"
      16             : #include "../../util/fd_util.h"
      17             : 
      18             : #pragma GCC diagnostic ignored "-Wconversion"
      19             : #pragma GCC diagnostic ignored "-Wsign-compare"
      20             : #pragma GCC diagnostic ignored "-Wsign-conversion"
      21             : 
      22             : static void
      23           0 : cleanup( struct pollfd * pfd ) {
      24           0 :   for( int i=0; pfd[i].fd >= -1; i++ ) {
      25           0 :     if( pfd[i].fd >= 0 ) {
      26           0 :       syscall( SYS_close, pfd[i].fd );
      27           0 :     }
      28           0 :   }
      29           0 : }
      30             : 
      31             : static ulong
      32           0 : mtime( void ) {
      33           0 :   struct timespec ts;
      34           0 :   if( clock_gettime( CLOCK_MONOTONIC, &ts ) < 0 && errno == ENOSYS )
      35           0 :     clock_gettime( CLOCK_REALTIME, &ts );
      36           0 :   return (ulong)ts.tv_sec * 1000
      37           0 :     + ts.tv_nsec / 1000000;
      38           0 : }
      39             : 
      40             : static int
      41             : start_tcp( struct pollfd * pfd,
      42             :            int             family,
      43             :            void const *    sa,
      44             :            socklen_t       sl,
      45             :            uchar const *   q,
      46           0 :            int             ql ) {
      47           0 :   struct msghdr mh = {
      48           0 :     .msg_name    = (void *)sa,
      49           0 :     .msg_namelen = sl,
      50           0 :     .msg_iovlen  = 2,
      51           0 :     .msg_iov = (struct iovec [2]){
      52           0 :       { .iov_base = (uint8_t[]){ ql>>8, ql }, .iov_len = 2 },
      53           0 :       { .iov_base = (void *)q, .iov_len = ql } },
      54           0 :     .msg_control    = NULL,
      55           0 :     .msg_controllen = 0,
      56           0 :     .msg_flags      = 0
      57           0 :   };
      58           0 :   int fd = socket( family, SOCK_STREAM|SOCK_CLOEXEC|SOCK_NONBLOCK, 0 );
      59           0 :   pfd->fd = fd;
      60           0 :   pfd->events = POLLOUT;
      61           0 :   if( !setsockopt( fd, IPPROTO_TCP, TCP_FASTOPEN_CONNECT,
      62           0 :       &(int){1}, sizeof(int) ) ) {
      63           0 :     int r = sendmsg( fd, &mh, MSG_FASTOPEN|MSG_NOSIGNAL );
      64           0 :     if( r == ql+2 ) pfd->events = POLLIN;
      65           0 :     if( r >= 0 ) return r;
      66           0 :     if( errno == EINPROGRESS ) return 0;
      67           0 :   }
      68           0 :   int r = connect( fd, sa, sl );
      69           0 :   if( !r || errno == EINPROGRESS ) return 0;
      70           0 :   close( fd );
      71           0 :   pfd->fd = -1;
      72           0 :   return -1;
      73           0 : }
      74             : 
      75             : static void
      76             : step_mh( struct msghdr * mh,
      77           0 :          size_t          n ) {
      78             :   /* Adjust iovec in msghdr to skip first n bytes. */
      79           0 :   while( mh->msg_iovlen && n >= mh->msg_iov->iov_len ) {
      80           0 :     n -= mh->msg_iov->iov_len;
      81           0 :     mh->msg_iov++;
      82           0 :     mh->msg_iovlen--;
      83           0 :   }
      84           0 :   if( !mh->msg_iovlen ) return;
      85           0 :   mh->msg_iov->iov_base = (char *)mh->msg_iov->iov_base + n;
      86           0 :   mh->msg_iov->iov_len -= n;
      87           0 : }
      88             : 
      89             : /* Internal contract for __res_msend[_rc]: asize must be >=512, nqueries
      90             :  * must be sufficiently small to be safe as VLA size. In practice it's
      91             :  * either 1 or 2, anyway. */
      92             : 
      93             : int
      94             : fd_res_msend_rc( int                     nqueries,
      95             :                  uchar const * const *   queries,
      96             :                  int const *             qlens,
      97             :                  uchar * const *         answers,
      98             :                  int *                   alens,
      99             :                  int                     asize,
     100           0 :                  fd_resolvconf_t const * conf ) {
     101           0 :   int fd;
     102           0 :   int servfail_retry = 0;
     103           0 :   union {
     104           0 :     struct sockaddr_in sin;
     105           0 :     struct sockaddr_in6 sin6;
     106           0 :   } sa = {0}, ns[MAXNS] = {0};
     107           0 :   socklen_t sl = sizeof sa.sin;
     108           0 :   int nns = 0;
     109           0 :   int family = AF_INET;
     110           0 :   int next;
     111           0 :   int i, j;
     112           0 :   struct pollfd pfd[nqueries+2];
     113           0 :   int qpos[nqueries], apos[nqueries];
     114           0 :   uchar alen_buf[nqueries][2];
     115             : 
     116           0 :   int timeout = 1000*conf->timeout;
     117           0 :   int attempts = conf->attempts;
     118             : 
     119           0 :   for( nns=0; nns<conf->nns; nns++ ) {
     120           0 :     const struct address *iplit = &conf->ns[nns];
     121           0 :     if( iplit->family == AF_INET ) {
     122           0 :       memcpy( &ns[nns].sin.sin_addr, iplit->addr, 4 );
     123           0 :       ns[nns].sin.sin_port = htons(53);
     124           0 :       ns[nns].sin.sin_family = AF_INET;
     125           0 :     } else {
     126           0 :       sl = sizeof sa.sin6;
     127           0 :       memcpy( &ns[nns].sin6.sin6_addr, iplit->addr, 16 );
     128           0 :       ns[nns].sin6.sin6_port = htons(53);
     129           0 :       ns[nns].sin6.sin6_scope_id = iplit->scopeid;
     130           0 :       ns[nns].sin6.sin6_family = family = AF_INET6;
     131           0 :     }
     132           0 :   }
     133             : 
     134             :   /* Get local address and open/bind a socket */
     135           0 :   fd = socket( family, SOCK_DGRAM|SOCK_CLOEXEC|SOCK_NONBLOCK, 0 );
     136             : 
     137             :   /* Handle case where system lacks IPv6 support */
     138           0 :   if( fd < 0 && family == AF_INET6 && errno == EAFNOSUPPORT ) {
     139           0 :     for( i=0; i<nns && conf->ns[nns].family == AF_INET6; i++ );
     140           0 :     if( i==nns ) {
     141           0 :       return -1;
     142           0 :     }
     143           0 :     fd = socket( AF_INET, SOCK_DGRAM|SOCK_CLOEXEC|SOCK_NONBLOCK, 0 );
     144           0 :     family = AF_INET;
     145           0 :     sl = sizeof sa.sin;
     146           0 :   }
     147             : 
     148             :   /* Convert any IPv4 addresses in a mixed environment to v4-mapped */
     149           0 :   if( fd >= 0 && family == AF_INET6 ) {
     150           0 :     setsockopt( fd, IPPROTO_IPV6, IPV6_V6ONLY, &(int){0}, sizeof 0 );
     151           0 :     for( i=0; i<nns; i++ ) {
     152           0 :       if( ns[i].sin.sin_family != AF_INET ) continue;
     153           0 :       memcpy( ns[i].sin6.sin6_addr.s6_addr+12, &ns[i].sin.sin_addr,             4 );
     154           0 :       memcpy( ns[i].sin6.sin6_addr.s6_addr,    "\0\0\0\0\0\0\0\0\0\0\xff\xff", 12 );
     155           0 :       ns[i].sin6.sin6_family = AF_INET6;
     156           0 :       ns[i].sin6.sin6_flowinfo = 0;
     157           0 :       ns[i].sin6.sin6_scope_id = 0;
     158           0 :     }
     159           0 :   }
     160             : 
     161           0 :   sa.sin.sin_family = family;
     162           0 :   if( fd < 0 || bind( fd, (void *)&sa, sl ) < 0 ) {
     163           0 :     if( fd >= 0 ) close( fd );
     164           0 :     return -1;
     165           0 :   }
     166             : 
     167             :   /* Past this point, there are no errors. Each individual query will
     168             :    * yield either no reply (indicated by zero length) or an answer
     169             :    * packet which is up to the caller to interpret. */
     170             : 
     171           0 :   for( i=0; i<nqueries; i++ ) pfd[i].fd = -1;
     172           0 :   pfd[nqueries].fd = fd;
     173           0 :   pfd[nqueries].events = POLLIN;
     174           0 :   pfd[nqueries+1].fd = -2;
     175             : 
     176           0 :   memset( alens, 0, sizeof *alens * nqueries );
     177             : 
     178           0 :   int retry_interval = timeout / attempts;
     179           0 :   next = 0;
     180           0 :   ulong t2 = mtime();
     181           0 :   ulong t0 = t2;
     182           0 :   ulong t1 = t2 - retry_interval;
     183             : 
     184           0 :   for( ; t2-t0 < timeout; t2=mtime() ) {
     185             :     /* This is the loop exit condition: that all queries
     186             :      * have an accepted answer. */
     187           0 :     for( i=0; i<nqueries && alens[i]>0; i++ );
     188           0 :     if( i==nqueries ) break;
     189             : 
     190           0 :     if( t2-t1 >= retry_interval ) {
     191             :       /* Query all configured namservers in parallel */
     192           0 :       for( i=0; i<nqueries; i++ )
     193           0 :         if( !alens[i] )
     194           0 :           for( j=0; j<nns; j++ )
     195           0 :             sendto( fd, queries[i],
     196           0 :               qlens[i], MSG_NOSIGNAL,
     197           0 :               (void *)&ns[j], sl );
     198           0 :       t1 = t2;
     199           0 :       servfail_retry = 2 * nqueries;
     200           0 :     }
     201             : 
     202             :     /* Wait for a response, or until time to retry */
     203           0 :     if( fd_syscall_poll( pfd, nqueries+1, t1+retry_interval-t2 ) <= 0 ) continue;
     204             : 
     205           0 :     while( next < nqueries ) {
     206           0 :       struct msghdr mh = {
     207           0 :         .msg_name = (void *)&sa,
     208           0 :         .msg_namelen = sl,
     209           0 :         .msg_iovlen = 1,
     210           0 :         .msg_iov = (struct iovec []){
     211           0 :           { .iov_base = (void *)answers[next],
     212           0 :             .iov_len = asize }
     213           0 :         },
     214           0 :         .msg_control    = NULL,
     215           0 :         .msg_controllen = 0,
     216           0 :         .msg_flags      = 0
     217           0 :       };
     218           0 :       int rlen = recvmsg( fd, &mh, 0 );
     219           0 :       if( rlen < 0 ) break;
     220             : 
     221             :       /* Ignore non-identifiable packets */
     222           0 :       if( rlen < 4 ) continue;
     223             : 
     224             :       /* Ignore replies from addresses we didn't send to */
     225           0 :       for( j=0; j<nns && memcmp( ns+j, &sa, sl ); j++ );
     226           0 :       if( j==nns ) continue;
     227             : 
     228             :       /* Find which query this answer goes with, if any */
     229           0 :       for( i=next; i<nqueries && (
     230           0 :         answers[next][0] != queries[i][0] ||
     231           0 :         answers[next][1] != queries[i][1] ); i++ );
     232           0 :       if( i==nqueries ) continue;
     233           0 :       if( alens[i]    ) continue;
     234             : 
     235             :       /* Only accept positive or negative responses;
     236             :        * retry immediately on server failure, and ignore
     237             :        * all other codes such as refusal. */
     238           0 :       switch( answers[next][3] & 15 ) {
     239           0 :       case 0:
     240           0 :       case 3:
     241           0 :         break;
     242           0 :       case 2:
     243           0 :         if( servfail_retry && servfail_retry-- )
     244           0 :           sendto( fd, queries[i], qlens[i], MSG_NOSIGNAL, (void *)&ns[j], sl );
     245           0 :         __attribute__((fallthrough));
     246           0 :       default:
     247           0 :         continue;
     248           0 :       }
     249             : 
     250             :       /* Store answer in the right slot, or update next
     251             :        * available temp slot if it's already in place. */
     252           0 :       alens[i] = rlen;
     253           0 :       if( i == next )
     254           0 :         for( ; next<nqueries && alens[next]; next++ );
     255           0 :       else
     256           0 :         memcpy( answers[i], answers[next], rlen );
     257             : 
     258             :       /* Ignore further UDP if all slots full or TCP-mode */
     259           0 :       if( next == nqueries ) pfd[nqueries].events = 0;
     260             : 
     261             :       /* If answer is truncated (TC bit), fallback to TCP */
     262           0 :       if( (answers[i][2] & 2) || (mh.msg_flags & MSG_TRUNC) ) {
     263           0 :         alens[i] = -1;
     264           0 :         int r = start_tcp( pfd+i, family, ns+j, sl, queries[i], qlens[i] );
     265           0 :         if( r >= 0 ) {
     266           0 :           qpos[i] = r;
     267           0 :           apos[i] = 0;
     268           0 :         }
     269           0 :         continue;
     270           0 :       }
     271           0 :     }
     272             : 
     273           0 :     for( i=0; i<nqueries; i++ ) if( pfd[i].revents & POLLOUT ) {
     274           0 :       struct msghdr mh = {
     275           0 :         .msg_iovlen = 2,
     276           0 :         .msg_iov = (struct iovec [2]){
     277           0 :           { .iov_base = (uint8_t[]){ qlens[i]>>8, qlens[i] }, .iov_len = 2 },
     278           0 :           { .iov_base = (void *)queries[i], .iov_len = qlens[i] } },
     279           0 :         .msg_control    = NULL,
     280           0 :         .msg_controllen = 0,
     281           0 :         .msg_flags      = 0
     282           0 :       };
     283           0 :       step_mh( &mh, qpos[i] );
     284           0 :       int r = sendmsg( pfd[i].fd, &mh, MSG_NOSIGNAL );
     285           0 :       if( r < 0 ) goto out;
     286           0 :       qpos[i] += r;
     287           0 :       if( qpos[i] == qlens[i]+2 )
     288           0 :         pfd[i].events = POLLIN;
     289           0 :     }
     290             : 
     291           0 :     for( i=0; i<nqueries; i++ ) if( pfd[i].revents & POLLIN ) {
     292           0 :       struct msghdr mh = {
     293           0 :         .msg_iovlen = 2,
     294           0 :         .msg_iov = (struct iovec [2]){
     295           0 :           { .iov_base = alen_buf[i], .iov_len = 2 },
     296           0 :           { .iov_base = answers[i], .iov_len = asize } },
     297           0 :         .msg_control    = NULL,
     298           0 :         .msg_controllen = 0,
     299           0 :         .msg_flags      = 0
     300           0 :       };
     301           0 :       step_mh( &mh, apos[i] );
     302           0 :       int r = recvmsg( pfd[i].fd, &mh, 0 );
     303           0 :       if( r <= 0 ) goto out;
     304           0 :       apos[i] += r;
     305           0 :       if( apos[i] < 2 ) continue;
     306           0 :       int alen = alen_buf[i][0]*256 + alen_buf[i][1];
     307           0 :       if( alen < 13 ) goto out;
     308           0 :       if( apos[i] < alen+2 && apos[i] < asize+2 )
     309           0 :         continue;
     310           0 :       int rcode = answers[i][3] & 15;
     311           0 :       if( rcode != 0 && rcode != 3 )
     312           0 :         goto out;
     313             : 
     314             :       /* Storing the length here commits the accepted answer.
     315             :          Immediately close TCP socket so as not to consume
     316             :          resources we no longer need. */
     317           0 :       alens[i] = alen;
     318           0 :       syscall( SYS_close, pfd[i].fd );
     319           0 :       pfd[i].fd = -1;
     320           0 :     }
     321           0 :   }
     322           0 : out:
     323           0 :   cleanup( pfd );
     324             : 
     325             :   /* Disregard any incomplete TCP results */
     326           0 :   for( i=0; i<nqueries; i++ ) if( alens[i]<0 ) alens[i] = 0;
     327             : 
     328           0 :   return 0;
     329           0 : }

Generated by: LCOV version 1.14