Skip to content

Commit 359836b

Browse files
committed
Implement deep unaliasing, and use it in interface dispatch resolution
1 parent 4d21e33 commit 359836b

2 files changed

Lines changed: 163 additions & 53 deletions

File tree

go/extractor/extractor.go

Lines changed: 78 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1510,8 +1510,22 @@ func extractSpec(tw *trap.Writer, spec ast.Spec, parent trap.Label, idx int) {
15101510
// extractType extracts type information for `tp` and returns its associated label;
15111511
// types are only extracted once, so the second time `extractType` is invoked it simply returns the label
15121512
func extractType(tw *trap.Writer, tp types.Type) trap.Label {
1513-
lbl, exists := getTypeLabel(tw, tp)
1513+
return extractTypeWithFlags(tw, tp, false)
1514+
}
1515+
1516+
func extractTypeWithFlags(tw *trap.Writer, tp types.Type, transparentAliases bool) trap.Label {
1517+
lbl, exists := getTypeLabelWithFlags(tw, tp, transparentAliases)
15141518
if !exists {
1519+
if !transparentAliases {
1520+
// Ensure the (deep) underlying type is also extracted, so that it is
1521+
// possible to implement deepUnalias in QL.
1522+
// For example, if we had type A = int and type B = string, we would need
1523+
// to extract map[string]int so that deepUnalias(map[B]A) has a real member
1524+
// of @type to return.
1525+
//
1526+
// TODO: consider using a newtype to do this instead.
1527+
extractTypeWithFlags(tw, tp, true)
1528+
}
15151529
var kind int
15161530
switch tp := tp.(type) {
15171531
case *types.Basic:
@@ -1523,10 +1537,10 @@ func extractType(tw *trap.Writer, tp types.Type) trap.Label {
15231537
case *types.Array:
15241538
kind = dbscheme.ArrayType.Index()
15251539
dbscheme.ArrayLengthTable.Emit(tw, lbl, fmt.Sprintf("%d", tp.Len()))
1526-
extractElementType(tw, lbl, tp.Elem())
1540+
extractElementType(tw, lbl, tp.Elem(), transparentAliases)
15271541
case *types.Slice:
15281542
kind = dbscheme.SliceType.Index()
1529-
extractElementType(tw, lbl, tp.Elem())
1543+
extractElementType(tw, lbl, tp.Elem(), transparentAliases)
15301544
case *types.Struct:
15311545
kind = dbscheme.StructType.Index()
15321546
for i := 0; i < tp.NumFields(); i++ {
@@ -1546,12 +1560,12 @@ func extractType(tw *trap.Writer, tp types.Type) trap.Label {
15461560
if field.Embedded() {
15471561
name = ""
15481562
}
1549-
extractComponentType(tw, lbl, i, name, field.Type())
1563+
extractComponentType(tw, lbl, i, name, field.Type(), transparentAliases)
15501564
dbscheme.ComponentTagsTable.Emit(tw, lbl, i, tp.Tag(i))
15511565
}
15521566
case *types.Pointer:
15531567
kind = dbscheme.PointerType.Index()
1554-
extractBaseType(tw, lbl, tp.Elem())
1568+
extractBaseType(tw, lbl, tp.Elem(), transparentAliases)
15551569
case *types.Interface:
15561570
kind = dbscheme.InterfaceType.Index()
15571571
for i := 0; i < tp.NumMethods(); i++ {
@@ -1561,51 +1575,51 @@ func extractType(tw *trap.Writer, tp types.Type) trap.Label {
15611575
// not dealt with by `extractScopes`
15621576
extractMethod(tw, meth)
15631577

1564-
extractComponentType(tw, lbl, i, meth.Name(), meth.Type())
1578+
extractComponentType(tw, lbl, i, meth.Name(), meth.Type(), transparentAliases)
15651579
}
15661580
for i := 0; i < tp.NumEmbeddeds(); i++ {
15671581
component := tp.EmbeddedType(i)
15681582
if isNonUnionTypeSetLiteral(component) {
15691583
component = createUnionFromType(component)
15701584
}
1571-
extractComponentType(tw, lbl, -(i + 1), "", component)
1585+
extractComponentType(tw, lbl, -(i + 1), "", component, transparentAliases)
15721586
}
15731587
case *types.Tuple:
15741588
kind = dbscheme.TupleType.Index()
15751589
for i := 0; i < tp.Len(); i++ {
1576-
extractComponentType(tw, lbl, i, "", tp.At(i).Type())
1590+
extractComponentType(tw, lbl, i, "", tp.At(i).Type(), transparentAliases)
15771591
}
15781592
case *types.Signature:
15791593
kind = dbscheme.SignatureType.Index()
15801594
params, results := tp.Params(), tp.Results()
15811595
if params != nil {
15821596
for i := 0; i < params.Len(); i++ {
15831597
param := params.At(i)
1584-
extractComponentType(tw, lbl, i+1, "", param.Type())
1598+
extractComponentType(tw, lbl, i+1, "", param.Type(), transparentAliases)
15851599
}
15861600
}
15871601
if results != nil {
15881602
for i := 0; i < results.Len(); i++ {
15891603
result := results.At(i)
1590-
extractComponentType(tw, lbl, -(i + 1), "", result.Type())
1604+
extractComponentType(tw, lbl, -(i + 1), "", result.Type(), transparentAliases)
15911605
}
15921606
}
15931607
if tp.Variadic() {
15941608
dbscheme.VariadicTable.Emit(tw, lbl)
15951609
}
15961610
case *types.Map:
15971611
kind = dbscheme.MapType.Index()
1598-
extractKeyType(tw, lbl, tp.Key())
1599-
extractElementType(tw, lbl, tp.Elem())
1612+
extractKeyType(tw, lbl, tp.Key(), transparentAliases)
1613+
extractElementType(tw, lbl, tp.Elem(), transparentAliases)
16001614
case *types.Chan:
16011615
kind = dbscheme.ChanTypes[tp.Dir()].Index()
1602-
extractElementType(tw, lbl, tp.Elem())
1616+
extractElementType(tw, lbl, tp.Elem(), transparentAliases)
16031617
case *types.Named:
16041618
origintp := tp.Origin()
16051619
kind = dbscheme.NamedType.Index()
16061620
dbscheme.TypeNameTable.Emit(tw, lbl, origintp.Obj().Name())
16071621
underlying := origintp.Underlying()
1608-
extractUnderlyingType(tw, lbl, underlying)
1622+
extractUnderlyingType(tw, lbl, underlying, transparentAliases)
16091623
trackInstantiatedStructFields(tw, tp, origintp)
16101624

16111625
extractTypeObject(tw, lbl, origintp.Obj())
@@ -1638,14 +1652,18 @@ func extractType(tw *trap.Writer, tp types.Type) trap.Label {
16381652
if term.Tilde() {
16391653
tildeStr = "~"
16401654
}
1641-
extractComponentType(tw, lbl, i, tildeStr, term.Type())
1655+
extractComponentType(tw, lbl, i, tildeStr, term.Type(), transparentAliases)
16421656
}
16431657
case *types.Alias:
1644-
kind = dbscheme.TypeAlias.Index()
1645-
dbscheme.TypeNameTable.Emit(tw, lbl, tp.Obj().Name())
1646-
dbscheme.AliasRhsTable.Emit(tw, lbl, extractType(tw, tp.Rhs()))
1658+
if transparentAliases {
1659+
extractTypeWithFlags(tw, tp.Rhs(), true)
1660+
} else {
1661+
kind = dbscheme.TypeAlias.Index()
1662+
dbscheme.TypeNameTable.Emit(tw, lbl, tp.Obj().Name())
1663+
dbscheme.AliasRhsTable.Emit(tw, lbl, extractType(tw, tp.Rhs()))
16471664

1648-
extractTypeObject(tw, lbl, tp.Obj())
1665+
extractTypeObject(tw, lbl, tp.Obj())
1666+
}
16491667
default:
16501668
log.Fatalf("unexpected type %T", tp)
16511669
}
@@ -1665,23 +1683,27 @@ func extractType(tw *trap.Writer, tp types.Type) trap.Label {
16651683
// is constructed from their globally unique ID. This prevents cyclic type keys
16661684
// since type recursion in Go always goes through named types.
16671685
func getTypeLabel(tw *trap.Writer, tp types.Type) (trap.Label, bool) {
1686+
return getTypeLabelWithFlags(tw, tp, false)
1687+
}
1688+
1689+
func getTypeLabelWithFlags(tw *trap.Writer, tp types.Type, transparentAliases bool) (trap.Label, bool) {
16681690
lbl, exists := tw.Labeler.TypeLabels[tp]
16691691
if !exists {
16701692
switch tp := tp.(type) {
16711693
case *types.Basic:
16721694
lbl = tw.Labeler.GlobalID(fmt.Sprintf("%d;basictype", tp.Kind()))
16731695
case *types.Array:
16741696
len := tp.Len()
1675-
elem := extractType(tw, tp.Elem())
1697+
elem := extractTypeWithFlags(tw, tp.Elem(), transparentAliases)
16761698
lbl = tw.Labeler.GlobalID(fmt.Sprintf("%d,{%s};arraytype", len, elem))
16771699
case *types.Slice:
1678-
elem := extractType(tw, tp.Elem())
1700+
elem := extractTypeWithFlags(tw, tp.Elem(), transparentAliases)
16791701
lbl = tw.Labeler.GlobalID(fmt.Sprintf("{%s};slicetype", elem))
16801702
case *types.Struct:
16811703
var b strings.Builder
16821704
for i := 0; i < tp.NumFields(); i++ {
16831705
field := tp.Field(i)
1684-
fieldTypeLbl := extractType(tw, field.Type())
1706+
fieldTypeLbl := extractTypeWithFlags(tw, field.Type(), transparentAliases)
16851707
if i > 0 {
16861708
b.WriteString(",")
16871709
}
@@ -1693,13 +1715,13 @@ func getTypeLabel(tw *trap.Writer, tp types.Type) (trap.Label, bool) {
16931715
}
16941716
lbl = tw.Labeler.GlobalID(fmt.Sprintf("%s;structtype", b.String()))
16951717
case *types.Pointer:
1696-
base := extractType(tw, tp.Elem())
1718+
base := extractTypeWithFlags(tw, tp.Elem(), transparentAliases)
16971719
lbl = tw.Labeler.GlobalID(fmt.Sprintf("{%s};pointertype", base))
16981720
case *types.Interface:
16991721
var b strings.Builder
17001722
for i := 0; i < tp.NumMethods(); i++ {
17011723
meth := tp.Method(i)
1702-
methLbl := extractType(tw, meth.Type())
1724+
methLbl := extractTypeWithFlags(tw, meth.Type(), transparentAliases)
17031725
if i > 0 {
17041726
b.WriteString(",")
17051727
}
@@ -1710,7 +1732,7 @@ func getTypeLabel(tw *trap.Writer, tp types.Type) (trap.Label, bool) {
17101732
if i > 0 {
17111733
b.WriteString(",")
17121734
}
1713-
fmt.Fprintf(&b, "{%s}", extractType(tw, tp.EmbeddedType(i)))
1735+
fmt.Fprintf(&b, "{%s}", extractTypeWithFlags(tw, tp.EmbeddedType(i), transparentAliases))
17141736
}
17151737
// We note whether the interface is comparable so that we can
17161738
// distinguish the underlying type of `comparable` from an
@@ -1722,7 +1744,7 @@ func getTypeLabel(tw *trap.Writer, tp types.Type) (trap.Label, bool) {
17221744
case *types.Tuple:
17231745
var b strings.Builder
17241746
for i := 0; i < tp.Len(); i++ {
1725-
compLbl := extractType(tw, tp.At(i).Type())
1747+
compLbl := extractTypeWithFlags(tw, tp.At(i).Type(), transparentAliases)
17261748
if i > 0 {
17271749
b.WriteString(",")
17281750
}
@@ -1734,7 +1756,7 @@ func getTypeLabel(tw *trap.Writer, tp types.Type) (trap.Label, bool) {
17341756
params, results := tp.Params(), tp.Results()
17351757
if params != nil {
17361758
for i := 0; i < params.Len(); i++ {
1737-
paramLbl := extractType(tw, params.At(i).Type())
1759+
paramLbl := extractTypeWithFlags(tw, params.At(i).Type(), transparentAliases)
17381760
if i > 0 {
17391761
b.WriteString(",")
17401762
}
@@ -1744,7 +1766,7 @@ func getTypeLabel(tw *trap.Writer, tp types.Type) (trap.Label, bool) {
17441766
b.WriteString(";")
17451767
if results != nil {
17461768
for i := 0; i < results.Len(); i++ {
1747-
resultLbl := extractType(tw, results.At(i).Type())
1769+
resultLbl := extractTypeWithFlags(tw, results.At(i).Type(), transparentAliases)
17481770
if i > 0 {
17491771
b.WriteString(",")
17501772
}
@@ -1756,12 +1778,12 @@ func getTypeLabel(tw *trap.Writer, tp types.Type) (trap.Label, bool) {
17561778
}
17571779
lbl = tw.Labeler.GlobalID(fmt.Sprintf("%s;signaturetype", b.String()))
17581780
case *types.Map:
1759-
key := extractType(tw, tp.Key())
1760-
value := extractType(tw, tp.Elem())
1781+
key := extractTypeWithFlags(tw, tp.Key(), transparentAliases)
1782+
value := extractTypeWithFlags(tw, tp.Elem(), transparentAliases)
17611783
lbl = tw.Labeler.GlobalID(fmt.Sprintf("{%s},{%s};maptype", key, value))
17621784
case *types.Chan:
17631785
dir := tp.Dir()
1764-
elem := extractType(tw, tp.Elem())
1786+
elem := extractTypeWithFlags(tw, tp.Elem(), transparentAliases)
17651787
lbl = tw.Labeler.GlobalID(fmt.Sprintf("%v,{%s};chantype", dir, elem))
17661788
case *types.Named:
17671789
origintp := tp.Origin()
@@ -1779,7 +1801,7 @@ func getTypeLabel(tw *trap.Writer, tp types.Type) (trap.Label, bool) {
17791801
case *types.Union:
17801802
var b strings.Builder
17811803
for i := 0; i < tp.Len(); i++ {
1782-
compLbl := extractType(tw, tp.Term(i).Type())
1804+
compLbl := extractTypeWithFlags(tw, tp.Term(i).Type(), transparentAliases)
17831805
if i > 0 {
17841806
b.WriteString("|")
17851807
}
@@ -1790,18 +1812,22 @@ func getTypeLabel(tw *trap.Writer, tp types.Type) (trap.Label, bool) {
17901812
}
17911813
lbl = tw.Labeler.GlobalID(fmt.Sprintf("%s;typesetliteraltype", b.String()))
17921814
case *types.Alias:
1793-
// Ensure that the definition of the aliased type gets extracted
1794-
// (which may be an alias in itself).
1795-
extractType(tw, tp.Rhs())
1815+
if transparentAliases {
1816+
lbl = extractTypeWithFlags(tw, tp.Rhs(), true)
1817+
} else {
1818+
// Ensure that the definition of the aliased type gets extracted
1819+
// (which may be an alias in itself).
1820+
extractType(tw, tp.Rhs())
17961821

1797-
entitylbl, exists := tw.Labeler.LookupObjectID(tp.Obj(), lbl)
1798-
if entitylbl == trap.InvalidLabel {
1799-
panic(fmt.Sprintf("Cannot construct label for alias type %v (underlying object is %v).\n", tp, tp.Obj()))
1800-
}
1801-
if !exists {
1802-
extractObject(tw, tp.Obj(), entitylbl)
1822+
entitylbl, exists := tw.Labeler.LookupObjectID(tp.Obj(), lbl)
1823+
if entitylbl == trap.InvalidLabel {
1824+
panic(fmt.Sprintf("Cannot construct label for alias type %v (underlying object is %v).\n", tp, tp.Obj()))
1825+
}
1826+
if !exists {
1827+
extractObject(tw, tp.Obj(), entitylbl)
1828+
}
1829+
lbl = tw.Labeler.GlobalID(fmt.Sprintf("{%s};aliastype", entitylbl))
18031830
}
1804-
lbl = tw.Labeler.GlobalID(fmt.Sprintf("{%s};aliastype", entitylbl))
18051831
default:
18061832
log.Fatalf("(getTypeLabel) unexpected type %T", tp)
18071833
}
@@ -1824,29 +1850,29 @@ func extractTypeObject(tw *trap.Writer, lbl trap.Label, entity *types.TypeName)
18241850
}
18251851

18261852
// extractKeyType extracts `key` as the key type of the map type `mp`
1827-
func extractKeyType(tw *trap.Writer, mp trap.Label, key types.Type) {
1828-
dbscheme.KeyTypeTable.Emit(tw, mp, extractType(tw, key))
1853+
func extractKeyType(tw *trap.Writer, mp trap.Label, key types.Type, transparentAliases bool) {
1854+
dbscheme.KeyTypeTable.Emit(tw, mp, extractTypeWithFlags(tw, key, transparentAliases))
18291855
}
18301856

18311857
// extractElementType extracts `element` as the element type of the container type `container`
1832-
func extractElementType(tw *trap.Writer, container trap.Label, element types.Type) {
1833-
dbscheme.ElementTypeTable.Emit(tw, container, extractType(tw, element))
1858+
func extractElementType(tw *trap.Writer, container trap.Label, element types.Type, transparentAliases bool) {
1859+
dbscheme.ElementTypeTable.Emit(tw, container, extractTypeWithFlags(tw, element, transparentAliases))
18341860
}
18351861

18361862
// extractBaseType extracts `base` as the base type of the pointer type `ptr`
1837-
func extractBaseType(tw *trap.Writer, ptr trap.Label, base types.Type) {
1838-
dbscheme.BaseTypeTable.Emit(tw, ptr, extractType(tw, base))
1863+
func extractBaseType(tw *trap.Writer, ptr trap.Label, base types.Type, transparentAliases bool) {
1864+
dbscheme.BaseTypeTable.Emit(tw, ptr, extractTypeWithFlags(tw, base, transparentAliases))
18391865
}
18401866

18411867
// extractUnderlyingType extracts `underlying` as the underlying type of the
18421868
// named type `named`
1843-
func extractUnderlyingType(tw *trap.Writer, named trap.Label, underlying types.Type) {
1844-
dbscheme.UnderlyingTypeTable.Emit(tw, named, extractType(tw, underlying))
1869+
func extractUnderlyingType(tw *trap.Writer, named trap.Label, underlying types.Type, transparentAliases bool) {
1870+
dbscheme.UnderlyingTypeTable.Emit(tw, named, extractTypeWithFlags(tw, underlying, transparentAliases))
18451871
}
18461872

18471873
// extractComponentType extracts `component` as the `idx`th component type of `parent` with name `name`
1848-
func extractComponentType(tw *trap.Writer, parent trap.Label, idx int, name string, component types.Type) {
1849-
dbscheme.ComponentTypesTable.Emit(tw, parent, idx, name, extractType(tw, component))
1874+
func extractComponentType(tw *trap.Writer, parent trap.Label, idx int, name string, component types.Type, transparentAliases bool) {
1875+
dbscheme.ComponentTypesTable.Emit(tw, parent, idx, name, extractTypeWithFlags(tw, component, transparentAliases))
18501876
}
18511877

18521878
// extractNumLines extracts lines-of-code and lines-of-comments information for the

0 commit comments

Comments
 (0)