diff --git a/src/main/java/org/xbill/DNS/Zone.java b/src/main/java/org/xbill/DNS/Zone.java index 644b3430..3ae4f2dd 100644 --- a/src/main/java/org/xbill/DNS/Zone.java +++ b/src/main/java/org/xbill/DNS/Zone.java @@ -11,6 +11,7 @@ import java.util.Map; import java.util.NoSuchElementException; import java.util.TreeMap; +import java.util.concurrent.locks.ReentrantReadWriteLock; /** * A DNS Zone. This encapsulates all data related to a Zone, and provides convenient lookup methods. @@ -19,7 +20,9 @@ */ public class Zone implements Serializable { - private static final long serialVersionUID = -9220510891189510942L; + private transient ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock(); + private transient ReentrantReadWriteLock.ReadLock readLock = readWriteLock.readLock(); + private transient ReentrantReadWriteLock.WriteLock writeLock = readWriteLock.writeLock(); /** A primary zone */ public static final int PRIMARY = 1; @@ -41,21 +44,24 @@ class ZoneIterator implements Iterator { private boolean wantLastSOA; ZoneIterator(boolean axfr) { - synchronized (Zone.this) { + readLock.lock(); + try { zentries = data.entrySet().iterator(); - } - wantLastSOA = axfr; - RRset[] sets = allRRsets(originNode); - current = new RRset[sets.length]; - for (int i = 0, j = 2; i < sets.length; i++) { - int type = sets[i].getType(); - if (type == Type.SOA) { - current[0] = sets[i]; - } else if (type == Type.NS) { - current[1] = sets[i]; - } else { - current[j++] = sets[i]; + wantLastSOA = axfr; + RRset[] sets = allRRsets(originNode); + current = new RRset[sets.length]; + for (int i = 0, j = 2; i < sets.length; i++) { + int type = sets[i].getType(); + if (type == Type.SOA) { + current[0] = sets[i]; + } else if (type == Type.NS) { + current[1] = sets[i]; + } else { + current[j++] = sets[i]; + } } + } finally { + readLock.unlock(); } } @@ -99,6 +105,15 @@ public void remove() { } } + /** + * Sets the Reader Lock for this Zone instance. Only to be used for testing. + * + * @param lock The Reader Lock to set for this Zone instance. + */ + void setLock(ReentrantReadWriteLock.ReadLock lock) { + readLock = lock; + } + private void validate() throws IOException { originNode = exactName(origin); if (originNode == null) { @@ -137,20 +152,25 @@ private void maybeAddRecord(Record record) throws IOException { * @see Master */ public Zone(Name zone, String file) throws IOException { - data = new TreeMap<>(); + writeLock.lock(); + try { + data = new TreeMap<>(); - if (zone == null) { - throw new IllegalArgumentException("no zone name specified"); - } - try (Master m = new Master(file, zone)) { - Record record; + if (zone == null) { + throw new IllegalArgumentException("no zone name specified"); + } + try (Master m = new Master(file, zone)) { + Record record; - origin = zone; - while ((record = m.nextRecord()) != null) { - maybeAddRecord(record); + origin = zone; + while ((record = m.nextRecord()) != null) { + maybeAddRecord(record); + } } + validate(); + } finally { + writeLock.unlock(); } - validate(); } /** @@ -161,33 +181,41 @@ public Zone(Name zone, String file) throws IOException { * @see Master */ public Zone(Name zone, Record[] records) throws IOException { - data = new TreeMap<>(); + writeLock.lock(); + try { + data = new TreeMap<>(); - if (zone == null) { - throw new IllegalArgumentException("no zone name specified"); - } - origin = zone; - for (Record record : records) { - maybeAddRecord(record); + if (zone == null) { + throw new IllegalArgumentException("no zone name specified"); + } + origin = zone; + for (Record record : records) { + maybeAddRecord(record); + } + validate(); + } finally { + writeLock.unlock(); } - validate(); } private void fromXFR(ZoneTransferIn xfrin) throws IOException, ZoneTransferException { - synchronized (this) { + writeLock.lock(); + try { data = new TreeMap<>(); - } - origin = xfrin.getName(); - xfrin.run(); - if (!xfrin.isAXFR()) { - throw new IllegalArgumentException("zones can only be created from AXFRs"); - } + origin = xfrin.getName(); + xfrin.run(); + if (!xfrin.isAXFR()) { + throw new IllegalArgumentException("zones can only be created from AXFRs"); + } - for (Record record : xfrin.getAXFR()) { - maybeAddRecord(record); + for (Record record : xfrin.getAXFR()) { + maybeAddRecord(record); + } + validate(); + } finally { + writeLock.unlock(); } - validate(); } /** @@ -231,43 +259,59 @@ public int getDClass() { return DClass.IN; } - private synchronized Object exactName(Name name) { - return data.get(name); + private Object exactName(Name name) { + readLock.lock(); + try { + Object val = data.get(name); + return val; + } finally { + readLock.unlock(); + } } - private synchronized RRset[] allRRsets(Object types) { + private RRset[] allRRsets(Object types) { if (types instanceof List) { - @SuppressWarnings("unchecked") - List typelist = (List) types; - return typelist.toArray(new RRset[0]); + readLock.lock(); + try { + @SuppressWarnings("unchecked") + List typelist = (List) types; + return typelist.toArray(new RRset[0]); + } finally { + readLock.unlock(); + } } else { RRset set = (RRset) types; return new RRset[] {set}; } } - private synchronized RRset oneRRset(Object types, int type) { + private RRset oneRRset(Object types, int type) { if (type == Type.ANY) { throw new IllegalArgumentException("oneRRset(ANY)"); } - if (types instanceof List) { - @SuppressWarnings("unchecked") - List list = (List) types; - for (RRset set : list) { + readLock.lock(); + try { + if (types instanceof List) { + @SuppressWarnings("unchecked") + List list = (List) types; + for (RRset set : list) { + if (set.getType() == type) { + return set; + } + } + } else { + RRset set = (RRset) types; if (set.getType() == type) { return set; } } - } else { - RRset set = (RRset) types; - if (set.getType() == type) { - return set; - } + } finally { + readLock.unlock(); } return null; } - private synchronized RRset findRRset(Name name, int type) { + private RRset findRRset(Name name, int type) { Object types = exactName(name); if (types == null) { return null; @@ -275,68 +319,78 @@ private synchronized RRset findRRset(Name name, int type) { return oneRRset(types, type); } - private synchronized void addRRset(Name name, RRset rrset) { - if (!hasWild && name.isWild()) { - hasWild = true; - } - Object types = data.get(name); - if (types == null) { - data.put(name, rrset); - return; - } - int rtype = rrset.getType(); - if (types instanceof List) { - @SuppressWarnings("unchecked") - List list = (List) types; - for (int i = 0; i < list.size(); i++) { - RRset set = list.get(i); - if (set.getType() == rtype) { - list.set(i, rrset); - return; - } + private void addRRset(Name name, RRset rrset) { + writeLock.lock(); + try { + if (!hasWild && name.isWild()) { + hasWild = true; } - list.add(rrset); - } else { - RRset set = (RRset) types; - if (set.getType() == rtype) { + Object types = data.get(name); + if (types == null) { data.put(name, rrset); - } else { - LinkedList list = new LinkedList<>(); - list.add(set); + return; + } + int rtype = rrset.getType(); + if (types instanceof List) { + @SuppressWarnings("unchecked") + List list = (List) types; + for (int i = 0; i < list.size(); i++) { + RRset set = list.get(i); + if (set.getType() == rtype) { + list.set(i, rrset); + return; + } + } list.add(rrset); - data.put(name, list); + } else { + RRset set = (RRset) types; + if (set.getType() == rtype) { + data.put(name, rrset); + } else { + LinkedList list = new LinkedList<>(); + list.add(set); + list.add(rrset); + data.put(name, list); + } } + } finally { + writeLock.unlock(); } } - private synchronized void removeRRset(Name name, int type) { - Object types = data.get(name); - if (types == null) { - return; - } - if (types instanceof List) { - @SuppressWarnings("unchecked") - List list = (List) types; - for (int i = 0; i < list.size(); i++) { - RRset set = list.get(i); - if (set.getType() == type) { - list.remove(i); - if (list.isEmpty()) { - data.remove(name); + private void removeRRset(Name name, int type) { + writeLock.lock(); + try { + Object types = data.get(name); + if (types == null) { + return; + } + if (types instanceof List) { + @SuppressWarnings("unchecked") + List list = (List) types; + for (int i = 0; i < list.size(); i++) { + RRset set = list.get(i); + if (set.getType() == type) { + list.remove(i); + if (list.isEmpty()) { + data.remove(name); + } + return; } + } + } else { + RRset set = (RRset) types; + if (set.getType() != type) { return; } + data.remove(name); } - } else { - RRset set = (RRset) types; - if (set.getType() != type) { - return; - } - data.remove(name); + } finally { + writeLock.unlock(); } } - private synchronized SetResponse lookup(Name name, int type) { + private SetResponse lookup(Name name, int type) { if (!name.subdomain(origin)) { return SetResponse.ofType(SetResponse.NXDOMAIN); } @@ -478,8 +532,13 @@ public RRset findExactMatch(Name name, int type) { * @see RRset */ public void addRRset(RRset rrset) { - Name name = rrset.getName(); - addRRset(name, rrset); + writeLock.lock(); + try { + Name name = rrset.getName(); + addRRset(name, rrset); + } finally { + writeLock.unlock(); + } } /** @@ -489,9 +548,10 @@ public void addRRset(RRset rrset) { * @see Record */ public void addRecord(T r) { - Name name = r.getName(); - int rtype = r.getRRsetType(); - synchronized (this) { + writeLock.lock(); + try { + Name name = r.getName(); + int rtype = r.getRRsetType(); RRset rrset = findRRset(name, rtype); if (rrset == null) { rrset = new RRset(r); @@ -499,6 +559,8 @@ public void addRecord(T r) { } else { rrset.addRR(r); } + } finally { + writeLock.unlock(); } } @@ -509,9 +571,10 @@ public void addRecord(T r) { * @see Record */ public void removeRecord(Record r) { - Name name = r.getName(); - int rtype = r.getRRsetType(); - synchronized (this) { + writeLock.lock(); + try { + Name name = r.getName(); + int rtype = r.getRRsetType(); RRset rrset = findRRset(name, rtype); if (rrset == null) { return; @@ -521,6 +584,8 @@ public void removeRecord(Record r) { } else { rrset.deleteRR(r); } + } finally { + writeLock.unlock(); } } @@ -538,7 +603,7 @@ public Iterator AXFR() { return new ZoneIterator(true); } - private void nodeToString(StringBuffer sb, Object node) { + private void nodeToString(StringBuilder sb, Object node) { RRset[] sets = allRRsets(node); for (RRset rrset : sets) { rrset.rrs().forEach(r -> sb.append(r).append('\n')); @@ -547,15 +612,20 @@ private void nodeToString(StringBuffer sb, Object node) { } /** Returns the contents of the Zone in master file format. */ - public synchronized String toMasterFile() { - StringBuffer sb = new StringBuffer(); - nodeToString(sb, originNode); - for (Map.Entry entry : data.entrySet()) { - if (!origin.equals(entry.getKey())) { - nodeToString(sb, entry.getValue()); + public String toMasterFile() { + readLock.lock(); + try { + StringBuilder sb = new StringBuilder(); + nodeToString(sb, originNode); + for (Map.Entry entry : data.entrySet()) { + if (!origin.equals(entry.getKey())) { + nodeToString(sb, entry.getValue()); + } } + return sb.toString(); + } finally { + readLock.unlock(); } - return sb.toString(); } /** Returns the contents of the Zone as a string (in master file format). */ diff --git a/src/test/java/org/xbill/DNS/ZoneTest.java b/src/test/java/org/xbill/DNS/ZoneTest.java index bd39e903..89211493 100644 --- a/src/test/java/org/xbill/DNS/ZoneTest.java +++ b/src/test/java/org/xbill/DNS/ZoneTest.java @@ -3,11 +3,15 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import java.io.IOException; import java.net.InetAddress; import java.util.Collections; import java.util.List; +import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.stream.Collectors; import java.util.stream.Stream; import org.junit.jupiter.api.Test; @@ -110,6 +114,16 @@ void wildNameAnyLookup() { resp.answers()); } + @Test + void testReadLocksAreAcquiredAndReleasedCorrectNumberOfTimes() { + Name testName = Name.fromConstantString("test.example."); + ReentrantReadWriteLock.ReadLock readLock = mock(ReentrantReadWriteLock.ReadLock.class); + ZONE.setLock(readLock); + SetResponse resp = ZONE.findRecords(testName, Type.ANY); + verify(readLock, times(5)).lock(); + verify(readLock, times(5)).unlock(); + } + private static List listOf(RRset... rrsets) { return Stream.of(rrsets).collect(Collectors.toList()); }