From c2ebd328979c081dd2c9fd0e359ed99473731d0e Mon Sep 17 00:00:00 2001 From: Naveen Saini Date: Fri, 27 Aug 2021 12:13:00 +0800 Subject: [PATCH 1/2] [X86] When storing v1i1/v2i1/v4i1 to memory, make sure we store zeros in the rest of the byte We can't store garbage in the unused bits. It possible that something like zextload from i1/i2/i4 is created to read the memory. Those zextloads would be legalized assuming the extra bits are 0. I'm not sure that the code in lowerStore is executed for the v1i1/v2i1/v4i1 case. It looks like the DAG combine in combineStore may have converted them to v8i1 first. And I think we're missing some cases to avoid going to the stack in the first place. But I don't have time to investigate those things at the moment so I wanted to focus on the correctness issue. Should fix PR48147. Reviewed By: RKSimon Differential Revision: https://reviews.llvm.org/D9129 Upstream-Status: Backport Signed-off-by:Craig Topper Signed-off-by: Naveen Saini --- llvm/lib/Target/X86/X86ISelLowering.cpp | 20 ++++++++++++++------ llvm/lib/Target/X86/X86InstrAVX512.td | 2 -- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 63eb050e9b3a..96b5e2cfbd82 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -22688,17 +22688,22 @@ static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, // Without AVX512DQ, we need to use a scalar type for v2i1/v4i1/v8i1 stores. if (StoredVal.getValueType().isVector() && StoredVal.getValueType().getVectorElementType() == MVT::i1) { - assert(StoredVal.getValueType().getVectorNumElements() <= 8 && - "Unexpected VT"); + unsigned NumElts = StoredVal.getValueType().getVectorNumElements(); + assert(NumElts <= 8 && "Unexpected VT"); assert(!St->isTruncatingStore() && "Expected non-truncating store"); assert(Subtarget.hasAVX512() && !Subtarget.hasDQI() && "Expected AVX512F without AVX512DQI"); + // We must pad with zeros to ensure we store zeroes to any unused bits. StoredVal = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1, DAG.getUNDEF(MVT::v16i1), StoredVal, DAG.getIntPtrConstant(0, dl)); StoredVal = DAG.getBitcast(MVT::i16, StoredVal); StoredVal = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, StoredVal); + // Make sure we store zeros in the extra bits. + if (NumElts < 8) + StoredVal = DAG.getZeroExtendInReg(StoredVal, dl, + MVT::getIntegerVT(NumElts)); return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(), St->getPointerInfo(), St->getAlignment(), @@ -41585,8 +41590,10 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG, EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), VT.getVectorNumElements()); StoredVal = DAG.getBitcast(NewVT, StoredVal); - - return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(), + SDValue Val = StoredVal.getOperand(0); + // We must store zeros to the unused bits. + Val = DAG.getZeroExtendInReg(Val, dl, MVT::i1); + return DAG.getStore(St->getChain(), dl, Val, St->getBasePtr(), St->getPointerInfo(), St->getAlignment(), St->getMemOperand()->getFlags()); } @@ -41602,10 +41609,11 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG, } // Widen v2i1/v4i1 stores to v8i1. - if ((VT == MVT::v2i1 || VT == MVT::v4i1) && VT == StVT && + if ((VT == MVT::v1i1 || VT == MVT::v2i1 || VT == MVT::v4i1) && VT == StVT && Subtarget.hasAVX512()) { unsigned NumConcats = 8 / VT.getVectorNumElements(); - SmallVector Ops(NumConcats, DAG.getUNDEF(VT)); + // We must store zeros to the unused bits. + SmallVector Ops(NumConcats, DAG.getConstant(0, dl, VT)); Ops[0] = StoredVal; StoredVal = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8i1, Ops); return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(), diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td index 32f012033fb0..d3b92183f87b 100644 --- a/llvm/lib/Target/X86/X86InstrAVX512.td +++ b/llvm/lib/Target/X86/X86InstrAVX512.td @@ -2888,8 +2888,6 @@ def : Pat<(i64 (bitconvert (v64i1 VK64:$src))), // Load/store kreg let Predicates = [HasDQI] in { - def : Pat<(store VK1:$src, addr:$dst), - (KMOVBmk addr:$dst, (COPY_TO_REGCLASS VK1:$src, VK8))>; def : Pat<(v1i1 (load addr:$src)), (COPY_TO_REGCLASS (KMOVBkm addr:$src), VK1)>; -- 2.17.1