Annotation of lucent/sys/src/boot/pc/bootp.c, revision 1.1.1.1

1.1       root        1: #include "u.h"
                      2: #include "lib.h"
                      3: #include "mem.h"
                      4: #include "dat.h"
                      5: #include "fns.h"
                      6: #include "io.h"
                      7: 
                      8: #include "ip.h"
                      9: 
                     10: static ushort tftpport = 5000;
                     11: static int Id = 1;
                     12: static Netaddr myaddr;
                     13: static Netaddr server;
                     14: 
                     15: typedef struct {
                     16:        uchar   header[4];
                     17:        uchar   data[Segsize];
                     18: } Tftp;
                     19: static Tftp tftpb;
                     20: 
                     21: static void
                     22: hnputs(uchar *ptr, ushort val)
                     23: {
                     24:        ptr[0] = val>>8;
                     25:        ptr[1] = val;
                     26: }
                     27: 
                     28: static void
                     29: hnputl(uchar *ptr, ulong val)
                     30: {
                     31:        ptr[0] = val>>24;
                     32:        ptr[1] = val>>16;
                     33:        ptr[2] = val>>8;
                     34:        ptr[3] = val;
                     35: }
                     36: 
                     37: static ulong
                     38: nhgetl(uchar *ptr)
                     39: {
                     40:        return ((ptr[0]<<24) | (ptr[1]<<16) | (ptr[2]<<8) | ptr[3]);
                     41: }
                     42: 
                     43: static ushort
                     44: nhgets(uchar *ptr)
                     45: {
                     46:        return ((ptr[0]<<8) | ptr[1]);
                     47: }
                     48: 
                     49: static short   endian  = 1;
                     50: static char*   aendian = (char*)&endian;
                     51: #define        LITTLE  *aendian
                     52: 
                     53: static ushort
                     54: ptcl_csum(void *a, int len)
                     55: {
                     56:        uchar *addr;
                     57:        ulong t1, t2;
                     58:        ulong losum, hisum, mdsum, x;
                     59: 
                     60:        addr = a;
                     61:        losum = 0;
                     62:        hisum = 0;
                     63:        mdsum = 0;
                     64: 
                     65:        x = 0;
                     66:        if((ulong)addr & 1) {
                     67:                if(len) {
                     68:                        hisum += addr[0];
                     69:                        len--;
                     70:                        addr++;
                     71:                }
                     72:                x = 1;
                     73:        }
                     74:        while(len >= 16) {
                     75:                t1 = *(ushort*)(addr+0);
                     76:                t2 = *(ushort*)(addr+2);        mdsum += t1;
                     77:                t1 = *(ushort*)(addr+4);        mdsum += t2;
                     78:                t2 = *(ushort*)(addr+6);        mdsum += t1;
                     79:                t1 = *(ushort*)(addr+8);        mdsum += t2;
                     80:                t2 = *(ushort*)(addr+10);       mdsum += t1;
                     81:                t1 = *(ushort*)(addr+12);       mdsum += t2;
                     82:                t2 = *(ushort*)(addr+14);       mdsum += t1;
                     83:                mdsum += t2;
                     84:                len -= 16;
                     85:                addr += 16;
                     86:        }
                     87:        while(len >= 2) {
                     88:                mdsum += *(ushort*)addr;
                     89:                len -= 2;
                     90:                addr += 2;
                     91:        }
                     92:        if(x) {
                     93:                if(len)
                     94:                        losum += addr[0];
                     95:                if(LITTLE)
                     96:                        losum += mdsum;
                     97:                else
                     98:                        hisum += mdsum;
                     99:        } else {
                    100:                if(len)
                    101:                        hisum += addr[0];
                    102:                if(LITTLE)
                    103:                        hisum += mdsum;
                    104:                else
                    105:                        losum += mdsum;
                    106:        }
                    107: 
                    108:        losum += hisum >> 8;
                    109:        losum += (hisum & 0xff) << 8;
                    110:        while(hisum = losum>>16)
                    111:                losum = hisum + (losum & 0xffff);
                    112: 
                    113:        return ~losum;
                    114: }
                    115: 
                    116: static ushort
                    117: ip_csum(uchar *addr)
                    118: {
                    119:        int len;
                    120:        ulong sum = 0;
                    121: 
                    122:        len = (addr[0]&0xf)<<2;
                    123: 
                    124:        while(len > 0) {
                    125:                sum += addr[0]<<8 | addr[1] ;
                    126:                len -= 2;
                    127:                addr += 2;
                    128:        }
                    129: 
                    130:        sum = (sum & 0xffff) + (sum >> 16);
                    131:        sum = (sum & 0xffff) + (sum >> 16);
                    132:        return (sum^0xffff);
                    133: }
                    134: 
                    135: static void
                    136: udpsend(int ctlrno, Netaddr *a, void *data, int dlen)
                    137: {
                    138:        Udphdr *uh;
                    139:        Etherhdr *ip;
                    140:        Etherpkt pkt;
                    141:        int len, ptcllen;
                    142: 
                    143: 
                    144:        uh = (Udphdr*)&pkt;
                    145: 
                    146:        memset(uh, 0, sizeof(Etherpkt));
                    147:        memmove(uh->udpcksum+sizeof(uh->udpcksum), data, dlen);
                    148: 
                    149:        /*
                    150:         * UDP portion
                    151:         */
                    152:        ptcllen = dlen + (UDP_HDRSIZE-UDP_PHDRSIZE);
                    153:        uh->ttl = 0;
                    154:        uh->udpproto = IP_UDPPROTO;
                    155:        uh->frag[0] = 0;
                    156:        uh->frag[1] = 0;
                    157:        hnputs(uh->udpplen, ptcllen);
                    158:        hnputl(uh->udpsrc, myaddr.ip);
                    159:        hnputs(uh->udpsport, myaddr.port);
                    160:        hnputl(uh->udpdst, a->ip);
                    161:        hnputs(uh->udpdport, a->port);
                    162:        hnputs(uh->udplen, ptcllen);
                    163:        uh->udpcksum[0] = 0;
                    164:        uh->udpcksum[1] = 0;
                    165:        dlen = (dlen+1)&~1;
                    166:        hnputs(uh->udpcksum, ptcl_csum(&uh->ttl, dlen+UDP_HDRSIZE));
                    167: 
                    168:        /*
                    169:         * IP portion
                    170:         */
                    171:        ip = (Etherhdr*)&pkt;
                    172:        len = sizeof(Udphdr)+dlen;
                    173:        ip->vihl = IP_VER|IP_HLEN;
                    174:        ip->tos = 0;
                    175:        ip->ttl = 255;
                    176:        hnputs(ip->length, len-ETHER_HDR);
                    177:        hnputs(ip->id, Id++);
                    178:        ip->frag[0] = 0;
                    179:        ip->frag[1] = 0;
                    180:        ip->cksum[0] = 0;
                    181:        ip->cksum[1] = 0;
                    182:        hnputs(ip->cksum, ip_csum(&ip->vihl));
                    183: 
                    184:        /*
                    185:         * Ethernet MAC portion
                    186:         */
                    187:        hnputs(ip->type, ET_IP);
                    188:        memmove(ip->d, a->ea, sizeof(ip->d));
                    189: 
                    190:        ethertxpkt(ctlrno, &pkt, len, Timeout);
                    191: }
                    192: 
                    193: static void
                    194: nak(int ctlrno, Netaddr *a, int code, char *msg, int report)
                    195: {
                    196:        int n;
                    197:        char buf[128];
                    198: 
                    199:        buf[0] = 0;
                    200:        buf[1] = Tftp_ERROR;
                    201:        buf[2] = 0;
                    202:        buf[3] = code;
                    203:        strcpy(buf+4, msg);
                    204:        n = strlen(msg) + 4 + 1;
                    205:        udpsend(ctlrno, a, buf, n);
                    206:        if(report)
                    207:                print("\ntftp: error(%d): %s\n", code, msg);
                    208: }
                    209: 
                    210: static void
                    211: dumpbytes(uchar *p, int n)
                    212: {
                    213:        while(n-- > 0)
                    214:                print("%2.2ux ", *p++);
                    215:        print("\n");
                    216: }
                    217: 
                    218: static int
                    219: udprecv(int ctlrno, Netaddr *a, void *data, int dlen)
                    220: {
                    221:        int n, len;
                    222:        ushort csm;
                    223:        Udphdr *h;
                    224:        ulong addr;
                    225:        Etherpkt pkt;
                    226: 
                    227:        for(;;) {
                    228:                n = etherrxpkt(ctlrno, &pkt, Timeout);
                    229:                if(n <= 0)
                    230:                        return 0;
                    231: 
                    232:                h = (Udphdr*)&pkt;
                    233:                if(nhgets(h->type) != ET_IP)
                    234:                        continue;
                    235: 
                    236:                if(ip_csum(&h->vihl)) {
                    237:                        print("ip chksum error\n");
                    238:                        continue;
                    239:                }
                    240:                if(h->vihl != (IP_VER|IP_HLEN)) {
                    241:                        print("ip bad vers/hlen\n");
                    242:                        continue;
                    243:                }
                    244: 
                    245:                if(h->udpproto != IP_UDPPROTO)
                    246:                        continue;
                    247: 
                    248:                h->ttl = 0;
                    249:                len = nhgets(h->udplen);
                    250:                hnputs(h->udpplen, len);
                    251: 
                    252:                if(nhgets(h->udpcksum)) {
                    253:                        csm = ptcl_csum(&h->ttl, len+UDP_PHDRSIZE);
                    254:                        if(csm != 0) {
                    255:                                print("udp chksum error csum #%4lux len %d\n", csm, n);
                    256:                                dumpbytes((uchar*)&pkt, n < 64 ? n : 64);
                    257:                                break;
                    258:                        }
                    259:                }
                    260: 
                    261:                if(a->port != 0 && nhgets(h->udpsport) != a->port)
                    262:                        continue;
                    263: 
                    264:                addr = nhgetl(h->udpsrc);
                    265:                if(a->ip != Bcastip && addr != a->ip)
                    266:                        continue;
                    267: 
                    268:                len -= UDP_HDRSIZE-UDP_PHDRSIZE;
                    269:                if(len > dlen) {
                    270:                        print("udp: packet too big\n");
                    271:                        continue;
                    272:                }
                    273: 
                    274:                memmove(data, h->udpcksum+sizeof(h->udpcksum), len);
                    275:                a->ip = addr;
                    276:                a->port = nhgets(h->udpsport);
                    277:                memmove(a->ea, pkt.s, sizeof(a->ea));
                    278: 
                    279:                return len;
                    280:        }
                    281: 
                    282:        return 0;
                    283: }
                    284: 
                    285: static int tftpblockno;
                    286: 
                    287: static int
                    288: tftpopen(int ctlrno, Netaddr *a, char *name, Tftp *tftp)
                    289: {
                    290:        int i, len, rlen;
                    291:        char buf[Segsize+2];
                    292: 
                    293:        buf[0] = 0;
                    294:        buf[1] = Tftp_READ;
                    295:        len = sprint(buf+2, "%s", name) + 2;
                    296:        len += sprint(buf+len+1, "octet") + 2;
                    297: 
                    298:        for(i = 0; i < 5; i++){
                    299:                udpsend(ctlrno, a, buf, len);
                    300:                a->port = 0;
                    301:                if((rlen = udprecv(ctlrno, a, tftp, sizeof(Tftp))) < sizeof(tftp->header))
                    302:                        continue;
                    303: 
                    304:                switch((tftp->header[0]<<8)|tftp->header[1]){
                    305: 
                    306:                case Tftp_ERROR:
                    307:                        print("tftpopen: error (%d): %s\n",
                    308:                                (tftp->header[2]<<8)|tftp->header[3], tftp->data);
                    309:                        return -1;
                    310: 
                    311:                case Tftp_DATA:
                    312:                        tftpblockno = 1;
                    313:                        len = (tftp->header[2]<<8)|tftp->header[3];
                    314:                        if(len != tftpblockno){
                    315:                                print("tftpopen: block error: %d\n", len);
                    316:                                nak(ctlrno, a, 1, "block error", 0);
                    317:                                return -1;
                    318:                        }
                    319:                        return rlen-sizeof(tftp->header);
                    320:                }
                    321:        }
                    322: 
                    323:        print("tftpopen: failed to connect to server\n");
                    324:        return -1;
                    325: }
                    326: 
                    327: static int
                    328: tftpread(int ctlrno, Netaddr *a, Tftp *tftp, int dlen)
                    329: {
                    330:        int blockno, len;
                    331:        uchar buf[4];
                    332: 
                    333:        buf[0] = 0;
                    334:        buf[1] = Tftp_ACK;
                    335:        buf[2] = tftpblockno>>8;
                    336:        buf[3] = tftpblockno;
                    337:        tftpblockno++;
                    338: 
                    339:        dlen += sizeof(tftp->header);
                    340: 
                    341: buggery:
                    342:        udpsend(ctlrno, a, buf, sizeof(buf));
                    343: 
                    344:        if((len = udprecv(ctlrno, a, tftp, dlen)) != dlen){
                    345:                print("tftpread: %d != %d\n", len, dlen);
                    346:                nak(ctlrno, a, 2, "short read", 0);
                    347:        }
                    348: 
                    349:        blockno = (tftp->header[2]<<8)|tftp->header[3];
                    350:        if(blockno != tftpblockno){
                    351:                print("tftpread: block error: %d, expected %d\n", blockno, tftpblockno);
                    352: 
                    353:                if(blockno == tftpblockno-1)
                    354:                        goto buggery;
                    355:                nak(ctlrno, a, 1, "block error", 0);
                    356: 
                    357:                return -1;
                    358:        }
                    359: 
                    360:        return len-sizeof(tftp->header);
                    361: }
                    362: 
                    363: static void
                    364: tftpclose(int ctlrno, Netaddr *a, uchar code, char *msg)
                    365: {
                    366:        nak(ctlrno, a, code, msg, 0);
                    367: }
                    368: 
                    369: int
                    370: bootp(int ctlrno, char *file)
                    371: {
                    372:        Bootp req, rep;
                    373:        int i, dlen, segsize, text, data, bss, total;
                    374:        uchar *ea, *addr, *p;
                    375:        ulong entry;
                    376:        Exec *exec;
                    377:        char name[128], *filename, *sysname;
                    378: 
                    379:        if((ea = etheraddr(ctlrno)) == 0){
                    380:                print("invalid ctlrno %d\n", ctlrno);
                    381:                return -1;
                    382:        }
                    383: 
                    384:        filename = 0;
                    385:        sysname = 0;
                    386:        if(file && *file){
                    387:                strcpy(name, file);
                    388:                if(filename = strchr(name, ':')){
                    389:                        if(filename != name && *(filename-1) != '\\'){
                    390:                                sysname = name;
                    391:                                *filename++ = 0;
                    392:                        }
                    393:                }
                    394:                else
                    395:                        filename = name;
                    396:        }
                    397:                
                    398: 
                    399:        memset(&req, 0, sizeof(req));
                    400:        req.op = Bootrequest;
                    401:        req.htype = 1;                  /* ethernet */
                    402:        req.hlen = Eaddrlen;            /* ethernet */
                    403:        memmove(req.chaddr, ea, Eaddrlen);
                    404: 
                    405:        myaddr.ip = 0;
                    406:        myaddr.port = BPportsrc;
                    407:        memmove(myaddr.ea, ea, sizeof(myaddr.ea));
                    408: 
                    409:        for(i = 0; i < 10; i++) {
                    410:                server.ip = Bcastip;
                    411:                server.port = BPportdst;
                    412:                memset(server.ea, 0xFF, sizeof(server.ea));
                    413:                udpsend(ctlrno, &server, &req, sizeof(req));
                    414:                if(udprecv(ctlrno, &server, &rep, sizeof(rep)) <= 0)
                    415:                        continue;
                    416:                if(memcmp(req.chaddr, rep.chaddr, Eaddrlen))
                    417:                        continue;
                    418:                if(rep.htype != 1 || rep.hlen != Eaddrlen)
                    419:                        continue;
                    420:                if(sysname == 0 || strcmp(sysname, rep.sname) == 0)
                    421:                        break;
                    422:        }
                    423:        if(i >= 10) {
                    424:                print("bootp timed out\n");
                    425:                return -1;
                    426:        }
                    427: 
                    428:        if(filename == 0 || *filename == 0)
                    429:                filename = rep.file;
                    430:        if(rep.sname[0] != '\0')
                    431:                print("%s:%s\n", rep.sname, filename);
                    432: 
                    433:        myaddr.ip = nhgetl(rep.yiaddr);
                    434:        myaddr.port = tftpport++;
                    435:        server.port = TFTPport;
                    436: 
                    437:        if((dlen = tftpopen(ctlrno, &server, filename, &tftpb)) < 0)
                    438:                return -1;
                    439: 
                    440:        exec = (Exec*)(tftpb.data);
                    441:        if(dlen < sizeof(Exec) || GLLONG(exec->magic) != I_MAGIC){
                    442:                nak(ctlrno, &server, 0, "bad magic number", 1);
                    443:                return -1;
                    444:        }
                    445:        text = GLLONG(exec->text);
                    446:        data = GLLONG(exec->data);
                    447:        bss = GLLONG(exec->bss);
                    448:        total = text+data+bss;
                    449:        entry = GLLONG(exec->entry);
                    450:        print("%d", text);
                    451: 
                    452:        addr = (uchar*)PADDR(entry);
                    453:        p = tftpb.data+sizeof(Exec);
                    454:        dlen -= sizeof(Exec);
                    455:        segsize = text;
                    456:        for(;;){
                    457:                if(dlen == 0){
                    458:                        if((dlen = tftpread(ctlrno, &server, &tftpb, sizeof(tftpb.data))) < 0)
                    459:                                return -1;
                    460:                        p = tftpb.data;
                    461:                }
                    462:                if(segsize <= dlen)
                    463:                        i = segsize;
                    464:                else
                    465:                        i = dlen;
                    466:                memmove(addr, p, i);
                    467: 
                    468:                addr += i;
                    469:                p += i;
                    470:                segsize -= i;
                    471:                dlen -= i;
                    472: 
                    473:                if(segsize <= 0){
                    474:                        if(data == 0)
                    475:                                break;
                    476:                        print("+%d", data);
                    477:                        segsize = data;
                    478:                        data = 0;
                    479:                        addr = (uchar*)PGROUND((ulong)addr);
                    480:                }
                    481:        }
                    482:        tftpclose(ctlrno, &server, 3, "ok");
                    483:        print("+%d=%d\n", bss, total);
                    484:        print("entry: 0x%lux\n", entry);
                    485: 
                    486:        (*(void(*)(void))(PADDR(entry)))();
                    487:        
                    488:        return 0;
                    489: }

unix.superglobalmegacorp.com

This archive runs on limited infrastructure. Preserving old code on modern bandwidth. Automated agents are requested to crawl responsibly.